Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing authored Jun 4, 2020
1 parent 83a0995 commit fb8d176
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
Cuda = True

print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = model.state_dict()
pretrained_dict = torch.load("model_data/yolo_weights.pth")
pretrained_dict = torch.load("model_data/yolo_weights.pth", map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
Expand Down

0 comments on commit fb8d176

Please sign in to comment.