Skip to content

Commit

Permalink
Supporting shared processor (microsoft#596)
Browse files Browse the repository at this point in the history
* Supporting shared processor

* fix readonly reverse bug

* remove pytests dependency

* with fit bug

* fix parameter error
  • Loading branch information
you-n-g authored Sep 13, 2021
1 parent 28c99c7 commit 51709c2
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 31 deletions.
91 changes: 67 additions & 24 deletions qlib/data/dataset/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,14 @@ class DataHandlerLP(DataHandler):

# process type
PTYPE_I = "independent"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by learn_processors
# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + learn_processors

# NOTE:
PTYPE_A = "append"
# - self._infer will be processed by infer_processors
# - self._learn will be processed by infer_processors + learn_processors

# - self._infer will be processed by shared_processors + infer_processors
# - self._learn will be processed by shared_processors + infer_processors + learn_processors
# - (e.g. self._infer processed by learn_processors )

def __init__(
Expand All @@ -308,8 +311,9 @@ def __init__(
start_time=None,
end_time=None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
infer_processors: List = [],
learn_processors: List = [],
shared_processors: List = [],
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
Expand Down Expand Up @@ -360,7 +364,8 @@ def __init__(
# Setup preprocessor
self.infer_processors = [] # for lint
self.learn_processors = [] # for lint
for pname in "infer_processors", "learn_processors":
self.shared_processors = [] # for lint
for pname in "infer_processors", "learn_processors", "shared_processors":
for proc in locals()[pname]:
getattr(self, pname).append(
init_instance_by_config(
Expand All @@ -375,9 +380,12 @@ def __init__(
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)

def get_all_processors(self):
    """
    Return every configured processor in application order.

    Shared processors run first (they feed both data flows), followed by
    the inference processors, then the learning processors.
    """
    # NOTE: the pre-change version omitted `shared_processors`; they must be
    # included so `fit()` visits them as well.
    return self.shared_processors + self.infer_processors + self.learn_processors

def fit(self):
    """
    Fit every configured processor on the raw data; the data itself is
    not transformed here (see `fit_process_data` for fit-and-transform).
    """
    all_procs = self.get_all_processors()
    for proc in all_procs:
        # time each processor's fit for diagnostics
        with TimeInspector.logt(f"{proc.__class__.__name__}"):
            proc.fit(self._data)
Expand All @@ -390,45 +398,80 @@ def fit_process_data(self):
"""
self.process_data(with_fit=True)

@staticmethod
def _run_proc_l(
    df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool
) -> pd.DataFrame:
    """
    Feed `df` through each processor in `proc_l` and return the result.

    Parameters
    ----------
    df : pd.DataFrame
        the data to be processed.
    proc_l : List[processor_module.Processor]
        processors applied in order.
    with_fit : bool
        if True, call `proc.fit` on the current data before applying each processor.
    check_for_infer : bool
        if True, raise TypeError for any processor not usable for inference.

    Returns
    -------
    pd.DataFrame
        the processed data.
    """
    out = df
    for proc in proc_l:
        if check_for_infer and not proc.is_for_infer():
            raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
        # time each processor for diagnostics
        with TimeInspector.logt(f"{proc.__class__.__name__}"):
            if with_fit:
                proc.fit(out)
            out = proc(out)
    return out

@staticmethod
def _is_proc_readonly(proc_l: List[processor_module.Processor]):
    """
    Return True iff every processor in `proc_l` treats its input as read-only
    (so the caller can skip a defensive copy of the data).

    NOTE: it will return True if `len(proc_l) == 0`
    """
    # idiomatic replacement for the manual early-return loop
    return all(p.readonly() for p in proc_l)

def process_data(self, with_fit: bool = False):
"""
process_data data. Fun `processor.fit` if necessary
Notation: (data) [processor]
# data processing flow of self.process_type == DataHandlerLP.PTYPE_I
(self._data)-[shared_processors]-(_shared_df)-[learn_processors]-(_learn_df)
\
-[infer_processors]-(_infer_df)
# data processing flow of self.process_type == DataHandlerLP.PTYPE_A
(self._data)-[shared_processors]-(_shared_df)-[infer_processors]-(_infer_df)-[learn_processors]-(_learn_df)
Parameters
----------
with_fit : bool
The input of the `fit` will be the output of the previous processor
"""
# shared data processors
# 1) assign
_shared_df = self._data
if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data
_shared_df = _shared_df.copy()
# 2) process
_shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True)

# data for inference
_infer_df = self._data
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
# 1) assign
_infer_df = _shared_df
if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data
_infer_df = _infer_df.copy()
# 2) process
_infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True)

for proc in self.infer_processors:
if not proc.is_for_infer():
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_infer_df)
_infer_df = proc(_infer_df)
self._infer = _infer_df

# data for learning
# 1) assign
if self.process_type == DataHandlerLP.PTYPE_I:
_learn_df = self._data
elif self.process_type == DataHandlerLP.PTYPE_A:
# based on `infer_df` and append the processor
_learn_df = _infer_df
else:
raise NotImplementedError(f"This type of input is not supported")

if len(self.learn_processors) > 0: # avoid modifying the original data
if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data
_learn_df = _learn_df.copy()
for proc in self.learn_processors:
with TimeInspector.logt(f"{proc.__class__.__name__}"):
if with_fit:
proc.fit(_learn_df)
_learn_df = proc(_learn_df)
# 2) process
_learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False)

self._learn = _learn_df

if self.drop_raw:
Expand Down
17 changes: 17 additions & 0 deletions qlib/data/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def is_for_infer(self) -> bool:
"""
return True

def readonly(self) -> bool:
    """
    Does the processor treat the input data as read-only (i.e. does not write to
    the input data) when processing?

    Knowing the read-only information helps the Handler avoid unnecessary copies.
    Defaults to False (assume the processor may mutate its input) — subclasses
    that never mutate should override this to return True.
    """
    return False

def config(self, **kwargs):
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
Expand All @@ -92,6 +100,9 @@ def __init__(self, fields_group=None):
def __call__(self, df):
    """Drop rows with NaN in the columns of the configured fields group."""
    subset_cols = get_group_columns(df, self.fields_group)
    return df.dropna(subset=subset_cols)

def readonly(self):
    # `DataFrame.dropna` returns a new frame, so the input is never mutated.
    return True


class DropnaLabel(DropnaProcessor):
def __init__(self, fields_group="label"):
Expand All @@ -113,6 +124,9 @@ def __call__(self, df):
mask = df.columns.isin(self.col_list)
return df.loc[:, ~mask]

def readonly(self):
    # column selection via `.loc` returns a new object; the input frame is not modified.
    return True


class FilterCol(Processor):
def __init__(self, fields_group="feature", col_list=[]):
Expand All @@ -128,6 +142,9 @@ def __call__(self, df):
mask = df.columns.get_level_values(-1).isin(self.col_list)
return df.loc[:, mask]

def readonly(self):
    # column filtering via `.loc` returns a new object; the input frame is not modified.
    return True


class TanhProcess(Processor):
"""Use tanh to process noise data"""
Expand Down
13 changes: 6 additions & 7 deletions tests/storage_tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from collections.abc import Iterable

import pytest
import numpy as np
from qlib.tests import TestAutoData

Expand Down Expand Up @@ -33,13 +32,13 @@ def test_calendar_storage(self):
print(f"calendar[-1]: {calendar[-1]}")

calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar.data)

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar[:])

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(calendar[0])

def test_instrument_storage(self):
Expand Down Expand Up @@ -90,10 +89,10 @@ def test_instrument_storage(self):
print(f"instrument['SH600000']: {instrument['SH600000']}")

instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(instrument.data)

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
print(instrument["sSH600000"])

def test_feature_storage(self):
Expand Down Expand Up @@ -152,7 +151,7 @@ def test_feature_storage(self):

feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)

with pytest.raises(IndexError):
with self.assertRaises(IndexError):
print(feature[0])
assert isinstance(
feature[815][1], (float, np.float32)
Expand Down

0 comments on commit 51709c2

Please sign in to comment.