forked from dvgodoy/PyTorchStepByStep
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_training.py
36 lines (30 loc) · 1.04 KB
/
model_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
Training loop.

Runs ``n_epochs`` epochs; in each epoch it runs the training set through
``one_training_step_fn`` and the validation set (with gradients disabled)
through ``one_val_step_fn``, appends the resulting epoch losses to
``train_losses`` / ``val_losses`` and logs both to ``writer`` under the main
tag "loss". The model's ``state_dict()`` is printed at the end.

NOTE(review): apart from ``torch`` and ``mini_batches_over_epoch`` this script
does not import the names it uses (``device``, ``train_loader``,
``val_loader``, ``model``, ``one_training_step_fn``, ``one_val_step_fn``,
``writer``). It presumably expects them to already exist in the executing
namespace (e.g. run via IPython ``%run -i`` after the data-preparation and
model-configuration scripts) -- confirm before running it standalone.
"""
import torch

# Presumed origin of the injected names (kept commented; see module docstring):
# from data_preparation import device, train_loader, val_loader, test_variable
# from model_configuration import model, one_training_step_fn, one_val_step_fn
from utils import mini_batches_over_epoch

n_epochs = 200

# One loss value per epoch for each split.
train_losses = []
val_losses = []

for epoch in range(n_epochs):
    # TRAINING - loss for this epoch as returned by mini_batches_over_epoch
    train_epoch_loss = mini_batches_over_epoch(device, train_loader, one_training_step_fn)
    train_losses.append(train_epoch_loss)

    # VALIDATION - no gradients in validation!
    with torch.no_grad():
        val_epoch_loss = mini_batches_over_epoch(device, val_loader, one_val_step_fn)
    val_losses.append(val_epoch_loss)

    # Record both losses for this epoch under the shared main tag "loss".
    writer.add_scalars(
        main_tag="loss",
        tag_scalar_dict={
            "training": train_epoch_loss,
            "validation": val_epoch_loss,
        },
        global_step=epoch,
    )

# Close the writer once, after every epoch has been logged (closing it inside
# the loop would end logging after the first epoch).
writer.close()

print(f"Done {__file__!r}")
print(model.state_dict())