-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
74 lines (61 loc) · 1.99 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import gc
import utils
from model import YOLO
from loss import YOLOLoss
from dataset import BananasDataset
# --- Optimisation hyper-parameters ----------------------------------------
LEARNING_RATE: float = 2e-5   # Adam learning rate
WEIGHT_DECAY: float = 0       # Adam weight_decay (0 disables it)
BATCH_SIZE: int = 4
EPOCHS: int = 15

# --- Runtime / data-loading settings --------------------------------------
DEVICE: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # first GPU when available
NUM_WORKERS: int = 2          # DataLoader worker processes
PIN_MEMORY: bool = True       # forwarded to DataLoader(pin_memory=...)
LOAD_MODEL: bool = False      # NOTE(review): not read anywhere in this file — confirm it is used elsewhere or remove
def train_fn(train_loader, model, optimizer, loss_fn):
    """Run one training epoch over ``train_loader`` and report the mean loss.

    Each batch is moved to ``DEVICE``, passed through ``model``, scored with
    ``loss_fn(out, y)``, and used for one optimizer step. A tqdm progress bar
    shows the current batch loss.

    Args:
        train_loader: Iterable yielding ``(x, y)`` batches that support ``.to(DEVICE)``.
        model: Module whose output is accepted by ``loss_fn``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(out, y)`` returning a scalar loss tensor.

    Returns:
        The mean per-batch loss as a float, or ``float('nan')`` if the loader
        yielded no batches.
    """
    # Ensure layers such as dropout/batch-norm (if the model has any) are in
    # training mode, in case an evaluation pass switched the model to eval().
    model.train()

    loop = tqdm(train_loader, leave=True)
    batch_losses = []
    for x, y in loop:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = loss_fn(out, y)

        # .item() copies the scalar to the host (a device sync), so do it once per batch.
        loss_value = loss.item()
        batch_losses.append(loss_value)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss_value)

    if not batch_losses:
        # An empty loader (e.g. dataset smaller than batch_size with
        # drop_last=True) would otherwise raise ZeroDivisionError here.
        print("Mean loss was nan (no batches)")
        return float("nan")

    mean_loss = sum(batch_losses) / len(batch_losses)
    print(f"Mean loss was {mean_loss}")
    return mean_loss
def main():
    """Build the YOLO model, train it on the banana dataset, and save its weights.

    Trains for ``EPOCHS`` epochs using ``train_fn``. Afterwards it writes the
    final ``state_dict`` to ``./doc/yolov1-weight-2.pt``.
    """
    from pathlib import Path

    # The grid size (S), boxes per cell (B) and class count (C) must agree
    # between the model and the loss, so they are defined once.
    grid_size, num_boxes, num_classes = 7, 2, 1

    save_path = Path('./doc/yolov1-weight-2.pt')
    # torch.save does not create parent directories. Create the directory now
    # so a missing folder fails immediately instead of after every epoch.
    save_path.parent.mkdir(parents=True, exist_ok=True)

    transform = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor()])
    model = YOLO(S=grid_size, B=num_boxes, C=num_classes).to(DEVICE)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    loss_fn = YOLOLoss(S=grid_size, B=num_boxes, C=num_classes)

    # NOTE(review): training reads the 'bananas_val' directory; confirm that
    # training on this split (rather than a dedicated training split) is intended.
    train_dataset = BananasDataset('data/banana-detection/bananas_val/', transform=transform)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        # Pinned host memory only benefits host->GPU copies; skip it on CPU-only runs.
        pin_memory=PIN_MEMORY and DEVICE.startswith('cuda'),
        shuffle=True,
        drop_last=True,
    )

    print('start train')
    for epoch in range(EPOCHS):
        print(f'epoch:{epoch + 1}')
        train_fn(train_loader, model, optimizer, loss_fn)
        # Free unreachable Python objects and cached CUDA blocks between epochs
        # (empty_cache has no effect when CUDA is not in use).
        gc.collect()
        torch.cuda.empty_cache()

    torch.save(model.state_dict(), save_path)
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()