Skip to content

Commit

Permalink
Updating stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Netzuel committed Oct 4, 2023
1 parent 52e82da commit ab4ebfa
Show file tree
Hide file tree
Showing 22 changed files with 1,432,804 additions and 4 deletions.
176 changes: 176 additions & 0 deletions Check_Performance_of_Model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "83f8ecbf",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import sys\n",
"from tqdm import tqdm\n",
"import torch.nn.functional as F\n",
"import math\n",
"import random\n",
"import h5py\n",
"# Import utils and models\n",
"import utils, models\n",
"\n",
"print(\"Is PyTorch using GPU?\", torch.cuda.is_available())\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"DTYPE = torch.float32"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bc27adc",
"metadata": {},
"outputs": [],
"source": [
"print(\"Loading data...\")\n",
"# Import data\n",
"hf = h5py.File(\"Data/GP_galactic_data.h5\", \"r\")\n",
"\n",
"X_training = torch.tensor( np.array(hf.get(\"X_training\")), dtype = DTYPE )\n",
"Y_training = torch.tensor( np.array(hf.get(\"Y_training\")), dtype = DTYPE )\n",
"\n",
"X_val = torch.tensor( np.array(hf.get(\"X_val\")), dtype = DTYPE )\n",
"Y_val = torch.tensor( np.array(hf.get(\"Y_val\")), dtype = DTYPE )\n",
"\n",
"X_test = torch.tensor( np.array(hf.get(\"X_test\")), dtype = DTYPE )\n",
"Y_test = torch.tensor( np.array(hf.get(\"Y_test\")), dtype = DTYPE )\n",
"\n",
"hf.close()\n",
"\n",
"print(\"Training shape: \", X_training.shape, Y_training.shape)\n",
"print(\"Val shape: \", X_val.shape, Y_val.shape)\n",
"print(\"Test shape: \", X_test.shape, Y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78670814",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 5\n",
"n_batches = int( X_test.shape[0] / batch_size )\n",
"\n",
"labeling_order = torch.unique(Y_test, return_counts = False).type(torch.int).detach().cpu().numpy().tolist()\n",
"print( \"labeling_order: \", labeling_order )\n",
"\n",
"# Compute weights from frequency of classes in the training set\n",
"unique_training, counts_training = torch.unique( Y_test, return_counts = True )\n",
"\n",
"inverse_freqs = 1 / counts_training\n",
"weights_tensor = ( inverse_freqs / torch.sum(inverse_freqs) ).to(device)\n",
"\n",
"# Define model\n",
"input_dim = 6\n",
"n_classes = len(labeling_order)\n",
"d_model = 32\n",
"nhead = 4\n",
"num_layers = 4\n",
"\n",
"model = models.TransformerClassifier(input_dim = input_dim, n_classes = n_classes, d_model = d_model, nhead = nhead, num_layers = num_layers, weights_tensor = weights_tensor, DTYPE = DTYPE, device = device, labeling_order = labeling_order).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cd7e61e",
"metadata": {},
"outputs": [],
"source": [
"# Load weights\n",
"model.load_state_dict( torch.load(\"Different_Tests/Test_1/Models_Data/Model_Saved/model_epoch_61.pt\") )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbf1f331",
"metadata": {},
"outputs": [],
"source": [
"y_true_list, y_pred_list = [], []\n",
"\n",
"for n in tqdm(range(n_batches)):\n",
" y_pred_batch = model(X_test[n*batch_size:(n+1)*batch_size,0,:,:].to(device))\n",
" y_pred_batch = torch.tensor([labeling_order[i] for i in torch.argmax(y_pred_batch, dim = 1).detach().cpu().numpy().tolist()], dtype = DTYPE).view(-1,1)\n",
" y_true = Y_test[n*batch_size:(n+1)*batch_size]\n",
" y_true_list.append(y_true)\n",
" y_pred_list.append(y_pred_batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa8d1624",
"metadata": {},
"outputs": [],
"source": [
"y_pred_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96409b9a",
"metadata": {},
"outputs": [],
"source": [
"model(X_training[n*batch_size:(n+1)*batch_size,0,:,:].to(device))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf12d396",
"metadata": {},
"outputs": [],
"source": [
"torch.unique(Y_test, return_counts = True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "174d13a6",
"metadata": {},
"outputs": [],
"source": [
"y_true"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit ab4ebfa

Please sign in to comment.