Skip to content

Commit

Permalink
Add functionality to plot true close vs prediction on TGCN
Browse files Browse the repository at this point in the history
  • Loading branch information
nickkarras committed May 6, 2023
1 parent 70f0e99 commit 7d20733
Show file tree
Hide file tree
Showing 8 changed files with 1,658 additions and 1,279 deletions.
361 changes: 343 additions & 18 deletions Data_preprocessing.ipynb

Large diffs are not rendered by default.

88 changes: 71 additions & 17 deletions TGCN_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@
"from torch_geometric_temporal.nn.recurrent import A3TGCN2, DCRNN, TGCN2, TGCN, GConvGRU, GConvLSTM, GCLSTM\n",
"from torch.nn import Linear\n",
"from torch.nn import ReLU\n",
"import matplotlib.pyplot as plt\n"
"import matplotlib.pyplot as plt\n",
"import matplotlib.dates as mdates"
]
},
{
Expand Down Expand Up @@ -232,7 +233,7 @@
"metadata": {},
"outputs": [],
"source": [
"edge_threshold = 0\n",
"edge_threshold = 0.2\n",
"\n",
"# train_edge_weights[train_edge_weights == 1] = 0 #We exclude self edges\n",
"train_edge_weights[train_edge_weights < edge_threshold] = 0\n",
Expand Down Expand Up @@ -511,10 +512,10 @@
" snapsot[3] = snapsot[3].squeeze(0)\n",
" \n",
" y_hat = model(snapsot[0], snapsot[1][-1], snapsot[2][-1])\n",
" y_true = y_hat\n",
" y_pred = snapsot[3]\n",
" y_true = snapsot[3]\n",
" y_pred = y_hat\n",
" \n",
" y_pred_ar = y_hat.detach().numpy()\n",
" y_pred_ar = y_pred.detach().numpy()\n",
" for y in y_pred_ar:\n",
" predictions.append(y.reshape(num_nodes))\n",
" \n",
Expand Down Expand Up @@ -764,6 +765,49 @@
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "cce38d7f",
"metadata": {},
"source": [
"<font size=\"3\"> \n",
"This function plots the true and predicted values for a given asset over time.\n",
"</font>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44b7fdee",
"metadata": {},
"outputs": [],
"source": [
"def plot_precitions_vs_actuals(pivot_true,pivot_pred,asset,metrics_plot_path):\n",
" # Convert datetime objects to matplotlib dates\n",
" dates_true = mdates.date2num(pivot_true.index)\n",
" dates_pred = mdates.date2num(pivot_pred.index)\n",
"\n",
" # Plotting the true values\n",
" plt.plot(dates_true, pivot_true[asset], label='True Values')\n",
" # Plotting the predicted values\n",
" plt.plot(dates_pred, pivot_pred[asset], label='Predictions')\n",
"\n",
" # Formatting the x-axis dates\n",
" plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%m/%d/%Y'))\n",
" plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=20)) # set tick interval to 30 days\n",
"\n",
" # Adding labels and title\n",
" plt.xlabel('Date')\n",
" plt.ylabel('Close')\n",
" plt.title(asset + ': Closes vs Predictions')\n",
"\n",
" # Adding legend\n",
" plt.legend()\n",
" plt.savefig(metrics_plot_path+asset+'_True_vs_Pred.pdf')\n",
" # Displaying the plot\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "8bcbf004",
Expand Down Expand Up @@ -896,7 +940,7 @@
"train_edge_weights_f = np.concatenate((train_edge_weights, val_edge_weights), axis=0)\n",
"train_node_features_f = np.concatenate((train_node_features, val_node_features), axis=0)\n",
"train_node_labels_f = np.concatenate((train_node_labels, val_node_labels), axis=0)\n",
"epochs = 30\n",
"epochs = 15\n",
"\n",
"train_loader,test_loader = graph_data_loader(batch_size, train_edges_f, train_edge_weights_f, train_node_features_f,\n",
" train_node_labels_f, test_edges, test_edge_weights, test_node_features, test_node_labels)\n",
Expand Down Expand Up @@ -945,6 +989,24 @@
"plot_single_metric(metrics_test['eval_mae_ls'],'MAE','Test',metrics_plot_path,gcn)"
]
},
{
"cell_type": "markdown",
"id": "53ca301d",
"metadata": {},
"source": [
"### Plot Predicted and True Closes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd220e76",
"metadata": {},
"outputs": [],
"source": [
"plot_precitions_vs_actuals(pivot_true,pivot_pred,'BTC',metrics_plot_path)"
]
},
{
"cell_type": "markdown",
"id": "f08f199a",
Expand Down Expand Up @@ -981,9 +1043,9 @@
"metadata": {},
"outputs": [],
"source": [
"epxeriment_name = 'Final_Proposed'\n",
"with open(results_path+gcn+'_'+epxeriment_name+'.txt', \"w\") as file:\n",
" file.write(str(metrics_test['eval_mae_ls']))"
"# epxeriment_name = 'Final_Proposed'\n",
"# with open(results_path+gcn+'_'+epxeriment_name+'.txt', \"w\") as file:\n",
"# file.write(str(metrics_test['eval_mae_ls']))"
]
},
{
Expand Down Expand Up @@ -1042,14 +1104,6 @@
"plt.savefig(results_path+\"Edges_Infulence.pdf\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c186bb6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Binary file modified io/output/exports/experiments_results/Edges_Infulence.pdf
Binary file not shown.
Binary file not shown.
Binary file modified io/output/exports/metrics_plots/TGCN2_Test_MAE.pdf
Binary file not shown.
Binary file modified io/output/exports/metrics_plots/TGCN2_Test_metrics.pdf
Binary file not shown.
Loading

0 comments on commit 7d20733

Please sign in to comment.