Skip to content

Commit

Permalink
support keep_history trader
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Jan 27, 2021
1 parent 25b04ae commit f746a4f
Show file tree
Hide file tree
Showing 14 changed files with 302 additions and 158 deletions.
16 changes: 8 additions & 8 deletions examples/recorders/joinquant_fund_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
sched = BackgroundScheduler()


# 周4抓取
@sched.scheduled_job('cron', hour=19, minute=00, day_of_week=3)
# 周6抓取
@sched.scheduled_job('cron', hour=10, minute=00, day_of_week=5)
def record_fund():
while True:
email_action = EmailInformer()

try:
# 基金和基金持仓数据
Fund.record_data(provider='joinquant', sleeping_time=1)
FundStock.record_data(provider='joinquant', sleeping_time=1)
# Fund.record_data(provider='joinquant', sleeping_time=1)
# FundStock.record_data(provider='joinquant', sleeping_time=1)
# 股票周线后复权数据
Stock1wkHfqKdata.record_data(provider='joinquant', sleeping_time=0)

Expand All @@ -36,8 +36,8 @@ def record_fund():
time.sleep(60)


# 周2抓取
@sched.scheduled_job('cron', hour=19, minute=00, day_of_week=1)
# 周6抓取
@sched.scheduled_job('cron', hour=13, minute=00, day_of_week=6)
def record_valuation():
while True:
email_action = EmailInformer()
Expand All @@ -58,10 +58,10 @@ def record_valuation():
if __name__ == '__main__':
init_log('joinquant_fund_runner.log')

record_valuation()

record_fund()

# record_valuation()

sched.start()

sched._thread.join()
74 changes: 74 additions & 0 deletions examples/trader/keep_run_trader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import logging

from zvt.api import get_top_volume_entities
from zvt.api.stats import get_top_fund_holding_stocks
from zvt.api.trader_info_api import clear_trader
from zvt.contract import IntervalLevel
from zvt.factors import TargetSelector, GoldCrossFactor, BullFactor
from zvt.trader import StockTrader
from zvt.utils.time_utils import split_time_interval, next_date

logger = logging.getLogger(__name__)


class MultipleLevelTrader(StockTrader):
def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp,
adjust_type=None):
start_timestamp = next_date(start_timestamp, -50)

# 周线策略
week_selector = TargetSelector(entity_ids=entity_ids, entity_schema=entity_schema, exchanges=exchanges,
codes=codes, start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
provider='joinquant', level=IntervalLevel.LEVEL_1WEEK, long_threshold=0.7)
week_bull_factor = BullFactor(entity_ids=entity_ids, entity_schema=entity_schema,
exchanges=exchanges,
codes=codes, start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
provider='joinquant', level=IntervalLevel.LEVEL_1WEEK)
week_selector.add_filter_factor(week_bull_factor)

# 日线策略
day_selector = TargetSelector(entity_ids=entity_ids, entity_schema=entity_schema, exchanges=exchanges,
codes=codes, start_timestamp=start_timestamp, end_timestamp=end_timestamp,
provider='joinquant', level=IntervalLevel.LEVEL_1DAY, long_threshold=0.7)
day_gold_cross_factor = GoldCrossFactor(entity_ids=entity_ids, entity_schema=entity_schema, exchanges=exchanges,
codes=codes, start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider='joinquant', level=IntervalLevel.LEVEL_1DAY)
day_selector.add_filter_factor(day_gold_cross_factor)

# 同时使用日线,周线级别
self.selectors.append(day_selector)
self.selectors.append(week_selector)


if __name__ == '__main__':
start = '2019-01-01'
end = '2021-01-01'
trader_name = 'keep_run_trader'
clear_trader(trader_name=trader_name)
for time_interval in split_time_interval(start=start, end=end, interval=40):
start_timestamp = time_interval[0]
end_timestamp = time_interval[-1]
# 成交量
vol_df = get_top_volume_entities(entity_type='stock',
start_timestamp=next_date(start_timestamp, -50),
end_timestamp=start_timestamp,
pct=0.3)
# 机构重仓
ii_df = get_top_fund_holding_stocks(timestamp=start_timestamp, pct=0.3, by='trading')

current_entity_pool = list(set(vol_df.index.tolist()) & set(ii_df.index.tolist()))

logger.info(f'current_entity_pool({len(current_entity_pool)}):{current_entity_pool}')

trader = MultipleLevelTrader(start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
entity_ids=current_entity_pool,
trader_name=trader_name,
keep_history=True,
draw_result=False,
rich_mode=False)
trader.run()
14 changes: 8 additions & 6 deletions examples/trader/macd_day_trader.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
# -*- coding: utf-8 -*-
from typing import List
from typing import List, Tuple

import pandas as pd

from zvt.contract import IntervalLevel
from zvt.factors import TargetSelector, GoldCrossFactor
from zvt.trader import TradingSignal
from zvt.trader.trader import StockTrader


# 依赖数据
# data_schema: Stock1dHfqKdata
# provider: joinquant
from zvt.utils import next_date


class MacdDayTrader(StockTrader):

def init_selectors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp,
adjust_type=None):
# 日线策略
start_timestamp = next_date(start_timestamp, -50)
day_selector = TargetSelector(entity_ids=entity_ids, entity_schema=entity_schema, exchanges=exchanges,
codes=codes, start_timestamp=start_timestamp, end_timestamp=end_timestamp,
provider='joinquant', level=IntervalLevel.LEVEL_1DAY, long_threshold=0.7)
Expand Down Expand Up @@ -68,10 +70,10 @@ def short_position_control(self):
# 空头仓位管理
return super().short_position_control()

def on_targets_selected(self, timestamp, level, selector: TargetSelector, long_targets: List[str],
short_targets: List[str]) -> List[str]:
def on_targets_filtered(self, timestamp, level, selector: TargetSelector, long_targets: List[str],
short_targets: List[str]) -> Tuple[List[str], List[str]]:
# 过滤某级别选出的 标的
return super().on_targets_selected(timestamp, level, selector, long_targets, short_targets)
return super().on_targets_filtered(timestamp, level, selector, long_targets, short_targets)


if __name__ == '__main__':
Expand Down
40 changes: 39 additions & 1 deletion tests/utils/test_time_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from zvt.contract import IntervalLevel
from zvt.utils.time_utils import evaluate_size_from_timestamp, next_timestamp, to_pd_timestamp, \
is_finished_kdata_timestamp
is_finished_kdata_timestamp, split_time_interval, is_same_date


def test_evaluate_size_from_timestamp():
Expand Down Expand Up @@ -49,3 +49,41 @@ def test_is_finished_kdata_timestamp():

timestamp = '2019-01-10'
assert is_finished_kdata_timestamp(timestamp, level=IntervalLevel.LEVEL_1DAY)


def test_split_time_interval():
first = None
last = None
start = '2020-01-01'
end = '2021-01-01'
for interval in split_time_interval(start, end, interval=30):
if first is None:
first = interval
last = interval

print(first)
print(last)

assert is_same_date(first[0], start)
assert is_same_date(first[-1], '2020-01-31')

assert is_same_date(last[-1], end)

def test_split_time_interval_month():
first = None
last = None
start = '2020-01-01'
end = '2021-01-01'
for interval in split_time_interval(start, end, method='month'):
if first is None:
first = interval
last = interval

print(first)
print(last)

assert is_same_date(first[0], start)
assert is_same_date(first[-1], '2020-01-31')

assert is_same_date(last[0], '2021-01-01')
assert is_same_date(last[-1], '2021-01-01')
11 changes: 10 additions & 1 deletion zvt/api/trader_info_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@
from zvt.contract.api import get_data, get_db_session
from zvt.contract.normal_data import NormalData
from zvt.contract.reader import DataReader
from zvt.domain import AccountStats, Order, trader_info
from zvt.domain import AccountStats, Order, trader_info, TraderInfo, Position
from zvt.contract.drawer import Drawer


def clear_trader(trader_name, session=None):
if not session:
session = get_db_session('zvt', data_schema=TraderInfo)
session.query(TraderInfo).filter(TraderInfo.trader_name == trader_name).delete()
session.query(AccountStats).filter(AccountStats.trader_name == trader_name).delete()
session.query(Position).filter(Position.trader_name == trader_name).delete()
session.query(Order).filter(Order.trader_name == trader_name).delete()
session.commit()

def get_trader_info(trader_name=None, return_type='df', start_timestamp=None, end_timestamp=None,
filters=None, session=None, order=None, limit=None) -> List[trader_info.TraderInfo]:
if trader_name:
Expand Down
6 changes: 3 additions & 3 deletions zvt/contract/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ def __init__(self,
def get_latest_saved_record(self, entity):
order = eval('self.data_schema.{}.desc()'.format(self.get_evaluated_time_field()))

# 对于k线这种数据,最后一个记录有可能是没完成的,所以取两个,总是删掉最后一个数据,更新之
# 对于k线这种数据,最后一个记录有可能是没完成的,所以取两个
# 同一周期内只保留最新的一个数据
records = get_data(entity_id=entity.id,
provider=self.provider,
data_schema=self.data_schema,
Expand All @@ -533,9 +534,8 @@ def get_latest_saved_record(self, entity):
# delete unfinished kdata
if len(records) == 2:
if is_in_same_interval(t1=records[0].timestamp, t2=records[1].timestamp, level=self.level):
self.session.delete(records[0])
self.session.delete(records[1])
self.session.flush()
return records[1]
return records[0]
return None

Expand Down
5 changes: 4 additions & 1 deletion zvt/contract/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,11 @@ def get_interval_timestamps(cls, start_date, end_date, level: IntervalLevel):
"""

for current_date in cls.get_trading_dates(start_date=start_date, end_date=end_date):
if level >= IntervalLevel.LEVEL_1DAY:
if level == IntervalLevel.LEVEL_1DAY:
yield current_date
elif level == IntervalLevel.LEVEL_1WEEK:
if current_date.weekday() == 4:
yield current_date
else:
start_end_list = cls.get_trading_intervals()

Expand Down
3 changes: 0 additions & 3 deletions zvt/domain/trader_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ class TraderInfo(TraderBase, Mixin):
# 机器人名字
trader_name = Column(String(length=128))

entity_ids = Column(String(length=1024))
entity_type = Column(String(length=128))
exchanges = Column(String(length=128))
codes = Column(String(length=128))
start_timestamp = Column(DateTime)
end_timestamp = Column(DateTime)
provider = Column(String(length=32))
Expand Down
9 changes: 5 additions & 4 deletions zvt/factors/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def consecutive_count(input_df, col, pattern=[-5, 1]):
count = -1
negative = count

if (count >= pattern[1]) and (negative <= pattern[0]):
input_df.loc[index, 'score'] = True
else:
input_df.loc[index, 'score'] = True
if pattern:
if (count >= pattern[1]) and (negative <= pattern[0]):
input_df.loc[index, 'score'] = True
else:
input_df.loc[index, 'score'] = False

# 设置目前状态
input_df.loc[index, 'count'] = count
Expand Down
17 changes: 2 additions & 15 deletions zvt/recorders/joinquant/quotes/jq_stock_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from zvt.contract import IntervalLevel, AdjustType
from zvt.contract.api import df_to_db
from zvt.contract.recorder import FixedCycleDataRecorder
from zvt.domain import Stock, StockKdataCommon, Stock1dHfqKdata
from zvt.domain import Stock, StockKdataCommon, Stock1dHfqKdata, Stock1wkHfqKdata
from zvt.recorders.joinquant.common import to_jq_trading_level, to_jq_entity_id
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import to_time_str, now_pd_timestamp, TIME_FORMAT_DAY, TIME_FORMAT_ISO8601
Expand Down Expand Up @@ -139,20 +139,7 @@ def generate_kdata_id(se):


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--level', help='trading level', default='1d', choices=[item.value for item in IntervalLevel])
parser.add_argument('--codes', help='codes', default=['000001'], nargs='+')
Stock1wkHfqKdata.record_data(codes=['300999'])

args = parser.parse_args()

level = IntervalLevel(args.level)
codes = args.codes

init_log('jq_china_stock_{}_kdata.log'.format(args.level))
JqChinaStockKdataRecorder(level=level, sleeping_time=0, codes=codes, real_time=False,
adjust_type=AdjustType.hfq,day_data=True).run()

print(get_kdata(entity_id='stock_sz_000001', limit=10, order=Stock1dHfqKdata.timestamp.desc(),
adjust_type=AdjustType.hfq))
# the __all__ is generated
__all__ = ['JqChinaStockKdataRecorder']
Loading

0 comments on commit f746a4f

Please sign in to comment.