Skip to content

Commit

Permalink
fix multi gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
iloveOREO committed Jan 8, 2025
1 parent a9afce4 commit c026e0f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions preprocess/filter_visual_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def func(paths, device_id):
model_hyper.train(False)

# load the pre-trained model on the koniq-10k dataset
model_hyper.load_state_dict((torch.load("checkpoints/auxiliary/koniq_pretrained.pkl")))
model_hyper.load_state_dict((torch.load("checkpoints/auxiliary/koniq_pretrained.pkl", map_location=device)))

transforms = torchvision.transforms.Compose(
[
Expand All @@ -78,7 +78,7 @@ def func(paths, device_id):
paras = model_hyper(video_frames) # 'paras' contains the network weights conveyed to target network

# Building target network
model_target = TargetNet(paras).cuda()
model_target = TargetNet(paras).to(device)
for param in model_target.parameters():
param.requires_grad = False

Expand Down

0 comments on commit c026e0f

Please sign in to comment.