Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitabelooussovbtis committed Jun 21, 2022
1 parent e9309af commit 1194650
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,26 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
f"Logging results to {colorstr('bold', save_dir)}\n"
f'Starting training for {epochs} epochs...')
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
# Trainloader
train_loader, dataset = create_dataloader(train_path,
imgsz,
batch_size // WORLD_SIZE,
gs,
single_cls,
hyp=hyp,
augment=True,
cache=None if opt.cache == 'val' else opt.cache,
rect=opt.rect,
rank=LOCAL_RANK,
workers=workers,
image_weights=opt.image_weights,
quad=opt.quad,
prefix=colorstr('train: '),
shuffle=True)
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
nb = len(train_loader) # number of batches
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

callbacks.run('on_train_epoch_start')
model.train()

Expand Down

0 comments on commit 1194650

Please sign in to comment.