Skip to content

Commit

Permalink
Add tests for SigAnaRecord
Browse files Browse the repository at this point in the history
  • Loading branch information
D-X-Y committed Mar 16, 2021
1 parent 9f57681 commit 6559d44
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
4 changes: 2 additions & 2 deletions qlib/workflow/record_temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ class SigAnaRecord(SignalRecord):

artifact_path = "sig_analysis"

def __init__(self, recorder, ana_long_short=False, ann_scaler=252):
super().__init__(recorder=recorder)
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
super().__init__(recorder=recorder, **kwargs)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler

Expand Down
22 changes: 12 additions & 10 deletions tests/test_all_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def train_with_sigana():
dataset = init_instance_by_config(task["dataset"])

# start exp
with R.start(experiment_name="workflow"):
with R.start(experiment_name="workflow_with_sigana"):
R.log_params(**flatten_dict(task))
model.fit(dataset)

Expand All @@ -163,7 +163,8 @@ def train_with_sigana():
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
return pred_score, {"ic": ic, "ric": ric}, rid
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path


def fake_experiment():
Expand Down Expand Up @@ -222,30 +223,31 @@ class TestAllFlow(TestAutoData):
def tearDownClass(cls) -> None:
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))

def test_0_train(self):
def test_0_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))

def test_1_train(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")

def test_1_backtest(self):
def test_2_backtest(self):
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,
"backtest failed",
)

def test_2_expmanager(self):
def test_3_expmanager(self):
pass_default, pass_current, uri_path = fake_experiment()
self.assertTrue(pass_default, msg="default uri is incorrect")
self.assertTrue(pass_current, msg="current uri is incorrect")
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))

def test_3_train_with_sigana(self):
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train_with_sigana()
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")


def suite():
_suite = unittest.TestSuite()
Expand Down

0 comments on commit 6559d44

Please sign in to comment.