# utils.py (forked from dvgodoy/PyTorchStepByStep)

import torch
import numpy as np
from typing import Callable


def get_one_training_step_fn(model, loss_fn: Callable, optimizer) -> Callable:
    """Build a function that performs a single training step on one mini-batch.

    TODO: redo with functools.partial
    """
    def one_training_step_fn(x: torch.Tensor, y: torch.Tensor) -> float:
        # set model to train mode
        model.train()
        # forward pass
        y_pred_tensor = model(x)
        # compute the loss (a scalar tensor)
        loss = loss_fn(y_pred_tensor, y)
        # backpropagate the gradients
        loss.backward()
        # update parameters and reset gradients for the next step
        optimizer.step()
        optimizer.zero_grad()
        # return the loss as a plain Python float
        return loss.item()
    return one_training_step_fn


def get_one_val_step_fn(model, loss_fn: Callable) -> Callable:
    """Build a function that performs a single validation step on one mini-batch.

    TODO: redo with functools.partial
    """
    def one_val_step_fn(x: torch.Tensor, y: torch.Tensor) -> float:
        # set model to eval mode
        model.eval()
        # no gradients are needed for validation
        with torch.no_grad():
            y_pred_tensor = model(x)
            # compute the loss (a scalar tensor)
            loss = loss_fn(y_pred_tensor, y)
        # return the loss as a plain Python float
        return loss.item()
    return one_val_step_fn
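

# The TODOs above suggest rewriting both factories with functools.partial.
# A minimal sketch of that idea follows; the helper names (_training_step,
# _val_step) are illustrative and not part of the original module.
from functools import partial


def _training_step(model, loss_fn: Callable, optimizer,
                    x: torch.Tensor, y: torch.Tensor) -> float:
    # same logic as one_training_step_fn, with the dependencies as leading arguments
    model.train()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()


def _val_step(model, loss_fn: Callable,
              x: torch.Tensor, y: torch.Tensor) -> float:
    # same logic as one_val_step_fn, with the dependencies as leading arguments
    model.eval()
    with torch.no_grad():
        loss = loss_fn(model(x), y)
    return loss.item()


# With partial, the factories become one-liners, e.g.:
#   train_step_fn = partial(_training_step, model, loss_fn, optimizer)
#   val_step_fn = partial(_val_step, model, loss_fn)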


def mini_batches_over_epoch(device, data_loader, one_step_fn) -> float:
    """Run one_step_fn over every mini-batch in data_loader and return the mean loss."""
    batch_losses = []
    for x_batch, y_batch in data_loader:
        # send the mini-batch to the same device as the model
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        loss = one_step_fn(x_batch, y_batch)
        batch_losses.append(loss)
    # average loss over the epoch, as a plain Python float
    return float(np.mean(batch_losses))
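

# A usage sketch wiring the helpers above into a minimal epoch loop.
# The toy data, model, loss function, optimizer, and hyperparameters below are
# placeholders chosen for illustration; they are not part of this module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # toy linear-regression data: y = 2x + 1 plus noise
    x = torch.randn(100, 1)
    y = 2 * x + 1 + 0.1 * torch.randn(100, 1)
    train_loader = DataLoader(TensorDataset(x[:80], y[:80]), batch_size=16, shuffle=True)
    val_loader = DataLoader(TensorDataset(x[80:], y[80:]), batch_size=16)

    model = torch.nn.Linear(1, 1).to(device)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    train_step_fn = get_one_training_step_fn(model, loss_fn, optimizer)
    val_step_fn = get_one_val_step_fn(model, loss_fn)

    for epoch in range(5):
        train_loss = mini_batches_over_epoch(device, train_loader, train_step_fn)
        val_loss = mini_batches_over_epoch(device, val_loader, val_step_fn)
        print(f"epoch {epoch}: train={train_loss:.4f} val={val_loss:.4f}")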