train.py (forked from comp-well-org/ESI)
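"""Training loop for the ESI signal-text model.

Uses HuggingFace Accelerate for distributed training and wandb logging,
a linear warmup/decay learning-rate schedule, mixed-precision forward
passes, and per-epoch checkpointing.
"""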
import time
import math
import torch
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from transformers import get_linear_schedule_with_warmup
from torch.cuda.amp import autocast
from tqdm import tqdm
import os

# find_unused_parameters=True lets DDP handle sub-modules that do not
# receive gradients in every forward pass.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train_epochs(model, dataloader, epoch, optimizer, scheduler, args, tb_writer=None):
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs],
                              log_with="wandb",
                              project_dir=args.save_path)
    accelerator.init_trackers(
        project_name="esi",
        config={"dropout": 0.1, "learning_rate": args.lr},
    )
    device = accelerator.device

    # Rebuild the scheduler with linear warmup over the first 10% of the
    # total training steps (this replaces the `scheduler` argument).
    num_training_steps = args.epochs * len(dataloader)
    warmup_steps = int(0.1 * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Accelerate wraps the model, optimizer, dataloader, and scheduler for
    # (multi-)GPU training and handles device placement, so the explicit
    # .to(device) calls below remain commented out.
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )
    # model.to(device)
    model.train()

    data_time_m = AverageMeter()
    end = time.time()
    step = 0
    for e in range(epoch):
        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            # Mixed-precision forward pass.
            with autocast():
                signals, texts = batch
                assert not torch.isnan(signals).any()
                # signals = signals.to(device=device, non_blocking=True)
                # texts = texts.to(device=device, non_blocking=True)
                input_ids = texts["input_ids"]
                attention_mask = texts["attention_mask"]
                data_time_m.update(time.time() - end)
                optimizer.zero_grad()
                # The 1D ResNet backbone expects (batch, channels, length).
                if args.signal_encoder == "xresnet1d101":
                    signals = signals.permute(0, 2, 1)
                # The model returns the total loss plus its captioning and
                # contrastive components.
                loss, caption_loss, contrastive_loss = model(
                    images=signals,
                    text=input_ids,
                    attention_mask=attention_mask,
                    return_loss=True,
                )
                # print(loss)

            accelerator.backward(loss)
            # if args.clip_grad > 0:
            #     torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optimizer.step()
            scheduler.step()

            accelerator.log({"training_loss": loss}, step=step)
            accelerator.log({"caption_loss": caption_loss}, step=step)
            accelerator.log({"contrastive_loss": contrastive_loss}, step=step)
            step += 1

        # Save a checkpoint at the end of each epoch.
        accelerator.save_model(model, os.path.join(args.save_path, f"epoch_{e}"))
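

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): a minimal smoke test of
# train_epochs() on random data. DummyPairs and DummySignalTextModel are
# hypothetical stand-ins that only mimic the interface the loop expects --
# batches of (signals, tokenized_text) and a model returning
# (loss, caption_loss, contrastive_loss). The argument values below are
# assumptions, and running it still requires `accelerate` and `wandb` to be
# installed and configured for logging.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.data import DataLoader, Dataset

    class DummyPairs(Dataset):
        """Random (signal, tokenized-text) pairs shaped like the real batches."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            signal = torch.randn(1000, 12)  # (sequence_length, channels)
            text = {
                "input_ids": torch.randint(0, 100, (32,)),
                "attention_mask": torch.ones(32, dtype=torch.long),
            }
            return signal, text

    class DummySignalTextModel(torch.nn.Module):
        """Stand-in with the call signature train_epochs() expects."""

        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(12, 1)

        def forward(self, images, text, attention_mask, return_loss=True):
            # Two toy loss terms so the return value matches
            # (total_loss, caption_loss, contrastive_loss).
            caption_loss = self.proj(images).pow(2).mean()
            contrastive_loss = self.proj(images).abs().mean()
            return caption_loss + contrastive_loss, caption_loss, contrastive_loss

    args = Namespace(lr=1e-4, epochs=1, save_path="./checkpoints",
                     signal_encoder="transformer")  # assumed argument values
    model = DummySignalTextModel()
    dataloader = DataLoader(DummyPairs(), batch_size=4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    # The scheduler argument is rebuilt inside train_epochs(), so None is fine.
    train_epochs(model, dataloader, args.epochs, optimizer, None, args)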