-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_pipelines.py
73 lines (55 loc) · 2.11 KB
/
train_pipelines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
def train_epoch(model, loader, optimizer, loss_fn, scheduler, device='cuda'):
    """Run one training pass over ``loader`` and return the accumulated loss per sample.

    For every batch the gradients are reset, the loss is back-propagated and
    ``optimizer`` is stepped. ``scheduler`` (if given) is stepped after every
    batch, i.e. once per optimizer step rather than once per epoch.

    Args:
        model: Network being trained; called as ``model(images)``.
        loader: Iterable of ``(images, targets)`` batches; ``len(loader)``
            sizes the progress bar.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        scheduler: Learning-rate scheduler stepped after every optimizer
            step, or ``None`` to skip scheduling.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Sum of the per-batch loss values divided by the number of samples
        seen, or ``float('nan')`` when ``loader`` yields no batches.
    """
    running_loss = 0.0
    number_of_samples = 0
    # Nothing inside the loop changes the module mode, so set it once.
    model.train()
    # The context manager closes the bar even if a batch raises.
    with tqdm(total=len(loader), desc='Training', position=0, leave=True) as pbar:
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            predictions = model(images)
            loss = loss_fn(predictions, targets)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            # NOTE(review): batch losses are summed un-weighted and divided by
            # the sample count below. That is a true per-sample mean only if
            # loss_fn uses reduction='sum'; with a 'mean' reduction the result
            # is scaled down by roughly the batch size — confirm intent.
            running_loss += loss.item()
            number_of_samples += images.shape[0]
            pbar.update()
    if number_of_samples == 0:
        # Empty loader: there is no loss to average, avoid dividing by zero.
        return float('nan')
    return running_loss / number_of_samples
def _concat_like_tolist(chunks):
    """Concatenate batch tensors along dim 0 with the dtype ``torch.tensor(t.tolist())`` would produce."""
    merged = torch.cat(chunks)
    if merged.is_floating_point():
        # Python floats become the default floating dtype in torch.tensor(...).
        return merged.to(torch.get_default_dtype())
    if merged.dtype != torch.bool and not merged.is_complex():
        # Python ints become int64 in torch.tensor(...).
        return merged.to(torch.int64)
    # bool is already what torch.tensor(...) would give; complex is passed
    # through unchanged (NOTE(review): torch.tensor would pick the default
    # complex dtype — not expected for this pipeline, confirm if needed).
    return merged


def test_model(model, testloader, loss_fn, verbose=False, device='cuda'):
    """Evaluate ``model`` on ``testloader`` and return the loss per sample.

    All predictions and targets are collected on the CPU, concatenated along
    the first dimension, and passed to ``loss_fn`` in a single call.

    Args:
        model: Network to evaluate; called as ``model(images)``.
        testloader: Iterable of ``(images, targets)`` batches;
            ``len(testloader)`` sizes the progress bar.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        verbose: Accepted for signature compatibility; not used here.
        device: Device the images are moved to for the forward pass.

    Returns:
        A tensor equal to ``loss_fn(all_predictions, all_targets)`` divided
        by the number of test samples, or a nan tensor when ``testloader``
        yields no batches.
    """
    prediction_batches = []
    target_batches = []
    # Nothing inside the loop changes the module mode, so set it once.
    model.eval()
    with torch.no_grad():
        # The context manager closes the bar even if a batch raises.
        with tqdm(total=len(testloader), desc='Testing', position=0, leave=True) as pbar:
            for images, targets in testloader:
                # Move outputs to CPU per batch so device memory does not grow
                # with the size of the test set.
                prediction_batches.append(model(images.to(device)).cpu())
                target_batches.append(targets.cpu())
                pbar.update()
        if not target_batches:
            # Empty loader: there is nothing to concatenate or average.
            return torch.tensor(float('nan'))
        all_predictions = _concat_like_tolist(prediction_batches)
        all_targets = _concat_like_tolist(target_batches)
        # NOTE(review): dividing by the sample count yields a per-sample mean
        # only if loss_fn uses reduction='sum' — confirm intent.
        return loss_fn(all_predictions, all_targets) / all_targets.shape[0]
def train_test(
    model,
    optimizer,
    scheduler,
    trainloader,
    testloader,
    loss_fn,
    num_epochs=30,
    verbose=False,
    device='cuda',
):
    """Alternate one training pass and one evaluation pass for ``num_epochs`` epochs.

    Each epoch calls ``train_epoch`` on ``trainloader`` and then
    ``test_model`` on ``testloader``, recording both results. When
    ``verbose`` is true, a one-line summary per epoch is printed through
    ``tqdm.write`` so it does not corrupt active progress bars.

    Returns:
        ``(train_losses, test_losses)``: two lists with one entry per epoch,
        holding whatever ``train_epoch`` and ``test_model`` returned.
    """
    train_history, test_history = [], []
    for epoch_idx in range(num_epochs):
        epoch_train = train_epoch(model, trainloader, optimizer, loss_fn, scheduler, device)
        train_history.append(epoch_train)
        epoch_test = test_model(model, testloader, loss_fn, verbose, device)
        test_history.append(epoch_test)
        if not verbose:
            continue
        # `!s` keeps the exact str() rendering (a tensor prints as "tensor(...)").
        tqdm.write(f'Epoch: {epoch_idx!s}, Train loss: {epoch_train!s}, Test loss: {epoch_test!s}')
    return train_history, test_history