diff --git a/MarT/main.py b/MarT/main.py index 8576c0f..fdc915e 100644 --- a/MarT/main.py +++ b/MarT/main.py @@ -156,7 +156,7 @@ def load_state_dict(): trainer.fit(lit_model, datamodule=data) path = model_checkpoint.best_model_path # load best model - lit_model.load_state_dict(torch.load(path, map_location='cuda:2')["state_dict"]) + lit_model.load_state_dict(torch.load(path, map_location='cuda')["state_dict"]) result = trainer.test(lit_model, datamodule=data) print(result)