Skip to content

Commit

Permalink
refactor trader
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Apr 4, 2023
1 parent 5351207 commit a5551a2
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 119 deletions.
12 changes: 6 additions & 6 deletions examples/trader/follow_ii_trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def on_time(self, timestamp: pd.Timestamp):
long_df = df[df["change_ratio"] > 0.05]
short_df = df[df["change_ratio"] < -0.5]
try:
self.trade_the_targets(
due_timestamp=timestamp,
happen_timestamp=timestamp,
long_selected=set(long_df["entity_id"].to_list()),
short_selected=set(short_df["entity_id"].to_list()),
)
long_targets = set(long_df["entity_id"].to_list())
short_targets = set(short_df["entity_id"].to_list())
if long_targets:
self.buy(timestamp=timestamp, entity_ids=long_targets)
if short_targets:
self.sell(timestamp=timestamp, entity_ids=short_targets)
except Exception as e:
self.logger.error(e)

Expand Down
4 changes: 2 additions & 2 deletions examples/trader/macd_day_trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def short_position_control(self):
# 空头仓位管理
return super().short_position_control()

def on_targets_filtered(
def on_factor_targets_filtered(
self, timestamp, level, factor: Factor, long_targets: List[str], short_targets: List[str]
) -> Tuple[List[str], List[str]]:
# 过滤某级别选出的 标的
return super().on_targets_filtered(timestamp, level, factor, long_targets, short_targets)
return super().on_factor_targets_filtered(timestamp, level, factor, long_targets, short_targets)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/zvt/contract/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def fill_gap(self):
self.result_df = self.result_df.reindex(new_index)
self.result_df = self.result_df.groupby(level=0).fillna(method=self.fill_method, limit=self.effective_number)

def update_entities(self, entity_ids):
def add_entities(self, entity_ids):
if (self.entity_ids and entity_ids) and (set(self.entity_ids) == set(entity_ids)):
self.logger.info(f"current: {self.entity_ids}")
self.logger.info(f"refresh: {entity_ids}")
Expand Down
3 changes: 1 addition & 2 deletions src/zvt/trader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ def __init__(
position_pct: float = 0,
order_money: float = 0,
):
self.trading_signal_type = trading_signal_type
self.entity_id = entity_id
self.due_timestamp = due_timestamp
self.happen_timestamp = happen_timestamp
self.trading_level = trading_level
self.trading_signal_type = trading_signal_type

# use position_pct or order_money
self.position_pct = position_pct

# when close the position,just use position_pct
self.order_money = order_money

Expand Down
Loading

0 comments on commit a5551a2

Please sign in to comment.