Skip to content

Commit

Permalink
Plot training history
Browse files Browse the repository at this point in the history
  • Loading branch information
bpesquet committed Sep 24, 2024
1 parent f8fc1cc commit b1645e6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ good-names=i,j,k,x,y
max-args = 6

# Maximum number of locals for function / method body
max-locals = 17
max-locals = 18

[TYPECHECK]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -50,6 +51,31 @@ def plot_fashion_images(data, device, model=None):
plt.show()


def plot_loss_acc(history, dataset="Training"):
"""Plot training loss and accuracy. Takes a Keras-like History object as parameter"""

loss_values = history["loss"]
recorded_epochs = range(1, len(loss_values) + 1)

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(recorded_epochs, loss_values, ".--", label=f"{dataset} loss")
ax1.set_ylabel("Loss")
ax1.legend()

acc_values = history["acc"]
ax2.plot(recorded_epochs, acc_values, ".--", label=f"{dataset} accuracy")
ax2.set_xlabel("Epochs")
ax2.set_ylabel("Accuracy")
plt.legend()

final_loss = loss_values[-1]
final_acc = acc_values[-1]
fig.suptitle(
f"{dataset} loss: {final_loss:.5f}. {dataset} accuracy: {final_acc*100:.2f}%"
)
plt.show()


def fetch_fashion_dataset(data_folder):
"""Download the Fashion-MNIST images dataset"""

Expand Down Expand Up @@ -118,12 +144,17 @@ def forward(self, x):
def fit(model, dataloader, criterion, optimizer, n_epochs, device):
"""Train a model on a dataset, using a predefined gradient descent optimizer"""

# Object storing training history
history = {"loss": [], "acc": []}

# Number of samples
n_samples = len(dataloader.dataset)

# Number of batches in an epoch (= n_samples / batch_size, rounded up)
n_batches = len(dataloader)

print(f"Training started! {n_samples} samples. {n_batches} batches per epoch")

# Train the model
for epoch in range(n_epochs):
# Total loss for epoch, divided by number of batches to obtain mean loss
Expand Down Expand Up @@ -160,6 +191,12 @@ def fit(model, dataloader, criterion, optimizer, n_epochs, device):
f"Epoch [{(epoch + 1):3}/{n_epochs:3}] finished. Mean loss: {mean_loss:.5f}. Accuracy: {epoch_acc * 100:.2f}%"
)

# Record epoch metrics for later plotting
history["loss"].append(mean_loss)
history["acc"].append(epoch_acc)

return history


def evaluate(model, dataloader, device):
"""Evaluate a model in inference mode"""
Expand Down Expand Up @@ -221,12 +258,18 @@ def test_feedforward_neural_network_fashion_images(show_plots=False):
# Adam optimizer for GD
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

fit(model, train_loader, criterion, optimizer, n_epochs, device)
history = fit(model, train_loader, criterion, optimizer, n_epochs, device)

# Evaluate model performance on test data
evaluate(model, test_loader, device)

if show_plots:
# Improve plots appearance
sns.set_theme()

# Plot training history
plot_loss_acc(history)

# Plot model predictions for some test images
plot_fashion_images(test_dataset, device, model)

Expand Down

0 comments on commit b1645e6

Please sign in to comment.