Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
flow3rdown authored Jul 28, 2024
1 parent ab28122 commit a722538
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion MarT/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def load_state_dict():
trainer.fit(lit_model, datamodule=data)
path = model_checkpoint.best_model_path
# load best model
lit_model.load_state_dict(torch.load(path, map_location='cuda:2')["state_dict"])
lit_model.load_state_dict(torch.load(path, map_location='cuda')["state_dict"])

result = trainer.test(lit_model, datamodule=data)
print(result)
Expand Down

0 comments on commit a722538

Please sign in to comment.