Skip to content

Commit

Permalink
0.8.26 update ts cache (waditu#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed May 9, 2022
1 parent 54719a5 commit 0dfd61a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 25 deletions.
34 changes: 20 additions & 14 deletions czsc/data/ts_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,22 @@ def update_bars_return(kline: pd.DataFrame, bar_numbers=None):

class TsDataCache:
"""Tushare 数据缓存"""
def __init__(self, data_path, sdt, edt):
def __init__(self, data_path, refresh=False, sdt="20120101", edt=datetime.now()):
"""
:param data_path: 数据路径
:param refresh: 是否刷新缓存
:param sdt: 缓存开始时间
:param edt: 缓存结束时间
"""
self.date_fmt = "%Y%m%d"
self.verbose = envs.get_verbose()
self.refresh = refresh
self.sdt = pd.to_datetime(sdt).strftime(self.date_fmt)
self.edt = pd.to_datetime(edt).strftime(self.date_fmt)
self.data_path = data_path
self.prefix = "TS_CACHE"
self.name = f"{self.prefix}_{self.sdt}_{self.edt}"
self.cache_path = os.path.join(self.data_path, self.name)
self.cache_path = os.path.join(self.data_path, self.prefix)
os.makedirs(self.cache_path, exist_ok=True)
self.pro = pro
self.__prepare_api_path()
Expand Down Expand Up @@ -105,31 +106,36 @@ def clear(self):
print(f"Tushare 数据缓存清理失败,请手动删除缓存文件夹:{self.cache_path}")

# ------------------------------------Tushare 原生接口----------------------------------------------
def ths_daily(self, ts_code, start_date, end_date, raw_bar=True):
def ths_daily(self, ts_code, start_date=None, end_date=None, raw_bar=True):
"""获取同花顺概念板块的日线行情"""
cache_path = self.api_path_map['ths_daily']
file_cache = os.path.join(cache_path, f"ths_daily_{ts_code}.pkl")
if os.path.exists(file_cache):
kline = io.read_pkl(file_cache)
file_cache = os.path.join(cache_path, f"ths_daily_{ts_code}_sdt{self.sdt}.feather")

if not self.refresh and os.path.exists(file_cache):
kline = pd.read_feather(file_cache)
if self.verbose:
print(f"ths_daily: read cache {file_cache}")
else:
if self.verbose:
print(f"ths_daily: refresh {file_cache}")
kline = pro.ths_daily(ts_code=ts_code, start_date=self.sdt, end_date=self.edt,
fields='ts_code,trade_date,open,close,high,low,vol')
kline = kline.sort_values('trade_date', ignore_index=True)
kline['trade_date'] = pd.to_datetime(kline['trade_date'], format=self.date_fmt)
kline['dt'] = kline['trade_date']
update_bars_return(kline)
io.save_pkl(kline, file_cache)
kline.to_feather(file_cache)

kline['trade_date'] = pd.to_datetime(kline['trade_date'], format=self.date_fmt)
start_date = pd.to_datetime(start_date)
end_date = pd.to_datetime(end_date)
bars = kline[(kline['trade_date'] >= start_date) & (kline['trade_date'] <= end_date)]
bars.reset_index(drop=True, inplace=True)
if start_date:
kline = kline[kline['trade_date'] >= pd.to_datetime(start_date)]
if end_date:
kline = kline[kline['trade_date'] <= pd.to_datetime(end_date)]

kline.reset_index(drop=True, inplace=True)
if raw_bar:
bars = format_kline(bars, freq=Freq.D)
return bars
kline = format_kline(kline, freq=Freq.D)
return kline

def ths_index(self, exchange="A", type_="N"):
"""获取同花顺概念
Expand Down
33 changes: 22 additions & 11 deletions examples/test_offline/test_ts_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,39 @@
os.environ['czsc_verbose'] = '1'


def test_ts_cache_ths_daily():
    """Exercise ``TsDataCache.ths_daily`` for the THS index ``885573.TI``.

    Three cache objects are created in turn (all rooted at the current
    directory) and the shape of the returned frame is checked each time:
    first with ``sdt='20200101'``, then with ``sdt='20210101'``, and finally
    with ``sdt='20210101'`` plus ``refresh=True``.

    NOTE(review): no ``edt`` is passed, so the expected row counts look tied
    to the data that was available when the test was written -- confirm
    before relying on the exact numbers.
    """
    ts_code = '885573.TI'

    # Cache starting at 2020-01-01: whole range, then capped by end_date.
    dc = TsDataCache(data_path='.', sdt='20200101')
    whole = dc.ths_daily(ts_code, raw_bar=False)
    assert whole.shape == (565, 34)
    capped = dc.ths_daily(ts_code, end_date="20220420", raw_bar=False)
    assert capped.shape == (556, 34)

    # Test passive data refresh (a different sdt, refresh flag left at its default).
    dc = TsDataCache(data_path='.', sdt='20210101')
    passive = dc.ths_daily(ts_code, raw_bar=False)
    assert passive.shape == (322, 34)

    # Test active data refresh (refresh=True).
    dc = TsDataCache(data_path='.', refresh=True, sdt='20210101')
    forced = dc.ths_daily(ts_code, raw_bar=False)
    assert forced.shape == (322, 34)

    dc.clear()


def test_ts_cache_daily_basic_new():
dc = TsDataCache(data_path='.', sdt='20200101', edt='20211024')
cache_path = './TS_CACHE_20200101_20211024'
assert os.path.exists(cache_path)
df = dc.daily_basic_new(trade_date='2018-03-15')
assert df.shape[0] == 3237 and df.shape[1] == 37

dfb = dc.stocks_daily_basic_new(sdt='20211001', edt='20211020')
assert dfb.shape[1] == df.shape[1] and len(dfb) == 40407
assert dfb.shape[1] == df.shape[1] + 1 and len(dfb) == 40407
dc.clear()


def test_ts_cache_bars():
"""测试获取K线"""
dc = TsDataCache(data_path='.', sdt='20200101', edt='20211024')
cache_path = './TS_CACHE_20200101_20211024'
assert os.path.exists(cache_path)

# 测试日线以上数据获取
bars = dc.pro_bar(ts_code='000001.SZ', asset='E', freq='D',
start_date='20200101', end_date='20211024', raw_bar=True)
Expand Down Expand Up @@ -71,9 +86,6 @@ def test_ts_cache_bars():

def test_ts_cache():
dc = TsDataCache(data_path='.', sdt='20200101', edt='20211024')
cache_path = './TS_CACHE_20200101_20211024'
assert os.path.exists(cache_path)

assert dc.get_next_trade_dates('2022-03-02', 2, 5) == ['20220304', '20220307', '20220308']
assert dc.get_next_trade_dates('2022-03-02', -1, -4) == ['20220224', '20220225', '20220228']
assert dc.get_dates_span('20220224', '20220228') == ['20220224', '20220225', '20220228']
Expand Down Expand Up @@ -124,5 +136,4 @@ def test_ts_cache():
bars = dc.ths_daily(ts_code='885566.TI', start_date='20200101', end_date='20211024', raw_bar=False)
assert len(bars) == 436

dc.clear()
assert not os.path.exists(cache_path)
dc.clear()

0 comments on commit 0dfd61a

Please sign in to comment.