Skip to content

Commit

Permalink
support for parquet in signal files
Browse files Browse the repository at this point in the history
  • Loading branch information
asavinov committed Mar 16, 2024
1 parent befa2c2 commit 400946d
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion common/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def train_lc(df_X, df_y, model_config: dict):
#
args = model_config.get("params").copy()
args["n_jobs"] = -1
args["verbose"] = 1
args["verbose"] = 0
model = LogisticRegression(**args)

#
Expand Down
8 changes: 4 additions & 4 deletions common/gen_labels_highlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _first_location_of_crossing_threshold(df, horizon, threshold, close_column_n
If the location (index) is 0 then it is the next point. If location (index) is NaN,
then the price does not cross the specified threshold during the horizon
(or there is not enough data, e.g., at the end of the series). Therefore, this
function can be used find whether the price will cross the threshold at all
function can be used to find whether the price will cross the threshold at all
during the specified horizon.
The function is somewhat similar to the tsfresh function first_location_of_maximum
Expand All @@ -192,7 +192,7 @@ def fn_high(x):
return np.nan
p = x[0, 0] # Reference price
p_threshold = p*(1+(threshold/100.0)) # Cross line
idx = np.argmax(x[1:, 1] > p_threshold) # First index where price crossed the threshold
idx = np.argmax(x[1:, 1] > p_threshold) # First index where price crosses the threshold

# If all elements are False, argmax returns 0 (first element of a constant series), which cannot be distinguished from the first element being True
# If the index is 0 and the first element is False (not above the threshold), then the result is NaN (the threshold is never exceeded)
Expand All @@ -205,7 +205,7 @@ def fn_low(x):
return np.nan
p = x[0, 0] # Reference price
p_threshold = p*(1+(threshold/100.0)) # Cross line
idx = np.argmax(x[1:, 1] < p_threshold) # First index where price crossed the threshold
idx = np.argmax(x[1:, 1] < p_threshold) # First index where price crosses the threshold

# If all elements are False, argmax returns 0 (first element of a constant series), which cannot be distinguished from the first element being True
# If the index is 0 and the first element is False (not below the threshold), then the result is NaN (the price never falls below the threshold)
Expand Down Expand Up @@ -257,7 +257,7 @@ def is_high_true(x):
elif np.isnan(x[1]):
return True
else:
return x[0] < x[1] # If the first cross point is closer to this point than the second one
return x[0] <= x[1] # If the first cross point is closer to this point than the second one

df[out_column] = df[["first_idx_column", "second_idx_column"]].apply(is_high_true, raw=True, axis=1)

Expand Down
11 changes: 9 additions & 2 deletions scripts/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,15 @@ def main(config_file):
#
out_path = data_path / App.config.get("signal_file_name")

print(f"Storing output file...")
out_df.to_csv(out_path.with_suffix(".csv"), index=False, float_format='%.4f')
print(f"Storing signals with {len(out_df)} records and {len(out_df.columns)} columns in output file {out_path}...")
if out_path.suffix == ".parquet":
out_df.to_parquet(out_path, index=False)
elif out_path.suffix == ".csv":
out_df.to_csv(out_path, index=False, float_format='%.6f')
else:
print(f"ERROR: Unknown extension of the 'signal_file_name' file '{out_path.suffix}'. Only 'csv' and 'parquet' are supported")
return

print(f"Signals stored in file: {out_path}. Length: {len(out_df)}. Columns: {len(out_df.columns)}")

elapsed = datetime.now() - now
Expand Down
11 changes: 9 additions & 2 deletions scripts/train_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,20 @@ def main(config_file):
#
# Load data with (rolling) label point-wise predictions and signals generated
#
file_path = (data_path / App.config.get("signal_file_name")).with_suffix(".csv")
file_path = data_path / App.config.get("signal_file_name")
if not file_path.exists():
print(f"ERROR: Input file does not exist: {file_path}")
return

print(f"Loading signals from input file: {file_path}")
df = pd.read_csv(file_path, parse_dates=[time_column], date_format="ISO8601", nrows=P.in_nrows)
if file_path.suffix == ".parquet":
df = pd.read_parquet(file_path)
elif file_path.suffix == ".csv":
df = pd.read_csv(file_path, parse_dates=[time_column], date_format="ISO8601", nrows=P.in_nrows)
else:
print(f"ERROR: Unknown extension of the 'signal_file_name' file '{file_path.suffix}'. Only 'csv' and 'parquet' are supported")
return

print(f"Signals loaded. Length: {len(df)}. Width: {len(df.columns)}")

#
Expand Down
2 changes: 1 addition & 1 deletion service/App.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class App:
"feature_file_name": "features.csv",
"matrix_file_name": "matrix.csv",
"predict_file_name": "predictions.csv", # predict, predict-rolling
"signal_file_name": "signals",
"signal_file_name": "signals.csv",
"signal_models_file_name": "signal_models",

"model_folder": "MODELS",
Expand Down

0 comments on commit 400946d

Please sign in to comment.