Skip to content

Commit

Permalink
add keep type in target selector
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Oct 22, 2021
1 parent fd9c2b3 commit fa53802
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 49 deletions.
15 changes: 11 additions & 4 deletions zvt/contract/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
from typing import List, Union, Optional, Type

import numpy as np
import pandas as pd
from sqlalchemy import Column, String, Text
from sqlalchemy.orm import declarative_base
Expand All @@ -16,7 +15,7 @@
from zvt.contract.reader import DataReader, DataListener
from zvt.contract.register import register_schema
from zvt.contract.zvt_context import factor_cls_registry
from zvt.utils.pd_utils import pd_is_not_null, drop_continue_duplicate
from zvt.utils.pd_utils import pd_is_not_null


class Indicator(object):
Expand Down Expand Up @@ -56,7 +55,7 @@ def transform(self, input_df: pd.DataFrame) -> pd.DataFrame:
else:
return g.apply(lambda x: self.transform_one(x.index[0][0], x.reset_index(level=0, drop=True)))

def transform_one(self, entity_id, df: pd.DataFrame) -> pd.DataFrame:
def transform_one(self, entity_id: str, df: pd.DataFrame) -> pd.DataFrame:
"""
df format:
Expand All @@ -67,6 +66,7 @@ def transform_one(self, entity_id, df: pd.DataFrame) -> pd.DataFrame:
the return result would change the columns and keep the format
:param entity_id:
:param df:
:return:
"""
Expand Down Expand Up @@ -471,11 +471,18 @@ def order_type_flag(order_type):
if not order_type:
return 'S'

def order_type_color(order_type):
if order_type:
return "#ec0000"
else:
return "#00da3c"

if pd_is_not_null(self.result_df):
annotation_df = self.result_df.copy()
annotation_df = annotation_df[annotation_df['score']]
annotation_df = annotation_df[~annotation_df['score'].isna()]
annotation_df['value'] = self.factor_df.loc[annotation_df.index]['close']
annotation_df['flag'] = annotation_df['score'].apply(lambda x: order_type_flag(x))
annotation_df['color'] = annotation_df['score'].apply(lambda x: order_type_color(x))
return annotation_df

def fill_gap(self):
Expand Down
35 changes: 16 additions & 19 deletions zvt/factors/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from zvt.utils.pd_utils import normal_index_df


def ma(s: pd.Series, window: int = 5):
def ma(s: pd.Series, window: int = 5) -> pd.Series:
"""
:param s:
Expand All @@ -16,7 +16,7 @@ def ma(s: pd.Series, window: int = 5):
return s.rolling(window=window, min_periods=window).mean()


def ema(s, window=12):
def ema(s: pd.Series, window: int = 12) -> pd.Series:
return s.ewm(span=window, adjust=False, min_periods=window).mean()


Expand All @@ -27,30 +27,25 @@ def live_or_dead(x):
return -1


def volume_up(s: pd.Series, window: int = 60):
"""
:param s:
:param window:
:return:
"""
ma_vol = s.rolling(window=window, min_periods=window).mean()
return s > ma_vol


def macd(s, slow=26, fast=12, n=9, return_type='df', normal=False, count_live_dead=False):
def macd(s: pd.Series,
slow: int = 26,
fast: int = 12,
n: int = 9,
return_type: str = 'df',
normal: bool = False,
count_live_dead: bool = False):
# 短期均线
ema_fast = ema(s, window=fast)
# 长期均线
ema_slow = ema(s, window=slow)

# 短期均线 - 长期均线 = 趋势的力度
diff = ema_fast - ema_slow
diff: pd.Series = ema_fast - ema_slow
# 力度均线
dea = diff.ewm(span=n, adjust=False).mean()
dea: pd.Series = diff.ewm(span=n, adjust=False).mean()

# 力度 的变化
m = (diff - dea) * 2
m: pd.Series = (diff - dea) * 2

# normal it
if normal:
Expand Down Expand Up @@ -150,8 +145,10 @@ def score(self, input_df) -> pd.DataFrame:


class MaTransformer(Transformer):
def __init__(self, windows=[5, 10], cal_change_pct=False) -> None:
def __init__(self, windows=None, cal_change_pct=False) -> None:
super().__init__()
if windows is None:
windows = [5, 10]
self.windows = windows
self.cal_change_pct = cal_change_pct

Expand Down Expand Up @@ -321,6 +318,6 @@ def calculate_score(df, factor_name, quantile):


# the __all__ is generated
__all__ = ['ma', 'ema', 'live_or_dead', 'volume_up', 'macd', 'point_in_range',
__all__ = ['ma', 'ema', 'live_or_dead', 'macd', 'point_in_range',
'intersect_ranges', 'combine', 'distance', 'intersect', 'RankScorer', 'MaTransformer',
'IntersectTransformer', 'MaAndVolumeTransformer', 'MacdTransformer', 'QuantileScorer']
8 changes: 5 additions & 3 deletions zvt/factors/ma/ma_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ def compute_result(self):
self.factor_df['close'] < 1.1 * self.factor_df[cols[0]])
for col in cols[1:]:
if self.over_mode == 'and':
filter_se = filter_se & (self.factor_df['close'] > self.factor_df[col])
filter_se = filter_se & ((self.factor_df['close'] > self.factor_df[col]) & (
self.factor_df['close'] < 1.1 * self.factor_df[col]))
else:
filter_se = filter_se | (self.factor_df['close'] > self.factor_df[col])
filter_se = filter_se | ((self.factor_df['close'] > self.factor_df[col]) & (
self.factor_df['close'] < 1.1 * self.factor_df[col]))
# 放量
if self.vol_windows:
vol_cols = [f'vol_ma{window}' for window in self.vol_windows]
Expand Down Expand Up @@ -200,4 +202,4 @@ def compute_result(self):
end_timestamp=now_pd_timestamp(), level=level, need_persist=False)
print(factor.result_df)
# the __all__ is generated
__all__ = ['get_ma_factor_schema', 'MaFactor', 'CrossMaFactor', 'VolumeUpMaFactor', 'CrossMaVolumeFactor']
__all__ = ['get_ma_factor_schema', 'MaFactor', 'CrossMaFactor', 'VolumeUpMaFactor', 'CrossMaVolumeFactor']
6 changes: 5 additions & 1 deletion zvt/factors/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,9 @@ def decode_rect(dct):

def decode_fenxing(dct):
return Fenxing(state=dct['state'], kdata=dct['kdata'], index=dct['index'])


# the __all__ is generated
__all__ = ['Direction', 'Fenxing', 'fenxing_power', 'a_include_b', 'get_direction', 'is_up', 'is_down', 'handle_first_fenxing', 'handle_zhongshu', 'handle_duan', 'handle_including', 'FactorStateEncoder', 'decode_rect', 'decode_fenxing']
__all__ = ['Direction', 'Fenxing', 'fenxing_power', 'a_include_b', 'get_direction', 'is_up', 'is_down',
'handle_first_fenxing', 'handle_zhongshu', 'handle_duan', 'handle_including', 'FactorStateEncoder',
'decode_rect', 'decode_fenxing']
59 changes: 37 additions & 22 deletions zvt/factors/target_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TargetType(Enum):
open_long = 'open_long'
# open_short 代表开空,并应该平掉相应标的的多单
open_short = 'open_short'
# 其他情况就是保持当前的持仓
keep = 'keep'


class SelectMode(Enum):
Expand Down Expand Up @@ -56,6 +56,7 @@ def __init__(self, entity_ids=None, entity_schema=Stock, exchanges=None, codes=N

self.open_long_df: DataFrame = None
self.open_short_df: DataFrame = None
self.keep_df: DataFrame = None

self.init_factors(entity_ids=entity_ids, entity_schema=entity_schema, exchanges=exchanges, codes=codes,
start_timestamp=start_timestamp, end_timestamp=end_timestamp, level=self.level)
Expand Down Expand Up @@ -91,7 +92,7 @@ def run(self):
"""
if self.filter_factors:
musts = []
filters = []
for factor in self.filter_factors:
df = factor.result_df

Expand All @@ -101,15 +102,15 @@ def run(self):
if len(df.columns) > 1:
s = df.agg("and", axis="columns")
s.name = 'score'
musts.append(s.to_frame(name='score'))
filters.append(s.to_frame(name='score'))
else:
df.columns = ['score']
musts.append(df)
filters.append(df)

if self.select_mode == SelectMode.condition_and:
self.filter_result = list(accumulate(musts, func=operator.__and__))[-1]
self.filter_result = list(accumulate(filters, func=operator.__and__))[-1]
else:
self.filter_result = list(accumulate(musts, func=operator.__or__))[-1]
self.filter_result = list(accumulate(filters, func=operator.__or__))[-1]
if self.score_factors:
scores = []
for factor in self.score_factors:
Expand Down Expand Up @@ -148,22 +149,36 @@ def get_open_short_targets(self, timestamp):

# overwrite it to generate targets
def generate_targets(self):
if pd_is_not_null(self.filter_result) and pd_is_not_null(self.score_result):
# for long
result1 = self.filter_result[self.filter_result.score]
result2 = self.score_result[self.score_result.score >= self.long_threshold]
long_result = result2.loc[result1.index, :]
# for short
result1 = self.filter_result[~self.filter_result.score]
result2 = self.score_result[self.score_result.score <= self.short_threshold]
short_result = result2.loc[result1.index, :]
elif pd_is_not_null(self.score_result):
long_result = self.score_result[self.score_result.score >= self.long_threshold]
short_result = self.score_result[self.score_result.score <= self.short_threshold]
else:
long_result = self.filter_result[self.filter_result.score == True]
short_result = self.filter_result[self.filter_result.score == False]
keep_result = pd.DataFrame()
long_result = pd.DataFrame()
short_result = pd.DataFrame()

if pd_is_not_null(self.filter_result):
keep_result = self.filter_result[self.filter_result['score'].isna()]
long_result = self.filter_result[self.filter_result['score'] == True]
short_result = self.filter_result[self.filter_result['score'] == False]

if pd_is_not_null(self.score_result):
score_keep_result = self.score_result[(self.score_result['score'] > self.short_threshold) & (
self.score_result['score'] < self.long_threshold)]
if pd_is_not_null(keep_result):
keep_result = score_keep_result.loc[keep_result.index, :]
else:
keep_result = score_keep_result

score_long_result = self.score_result[self.score_result['score'] >= self.long_threshold]
if pd_is_not_null(long_result):
long_result = score_long_result.loc[long_result.index, :]
else:
long_result = score_long_result

score_short_result = self.score_result[self.score_result['score'] <= self.short_threshold]
if pd_is_not_null(short_result):
short_result = score_short_result.loc[short_result.index, :]
else:
short_result = score_short_result

self.keep_df = self.normalize_result_df(keep_result)
self.open_long_df = self.normalize_result_df(long_result)
self.open_short_df = self.normalize_result_df(short_result)

Expand Down Expand Up @@ -203,4 +218,4 @@ def draw(self,


# the __all__ is generated
__all__ = ['TargetType', 'SelectMode', 'TargetSelector']
__all__ = ['TargetType', 'SelectMode', 'TargetSelector']

0 comments on commit fa53802

Please sign in to comment.