diff --git a/gui.py b/gui.py index 19b0c51..0755841 100644 --- a/gui.py +++ b/gui.py @@ -46,7 +46,11 @@ SAM_MODEL_TYPE = "vit_b" MedSAM_CKPT_PATH = "work_dir/MedSAM/medsam_vit_b.pth" MEDSAM_IMG_INPUT_SIZE = 1024 -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +if torch.backends.mps.is_available(): + device = torch.device("mps") +else: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @torch.no_grad() diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py index c8a8dc6..ace39e5 100644 --- a/segment_anything/build_sam.py +++ b/segment_anything/build_sam.py @@ -141,6 +141,6 @@ def _build_sam( if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) sam.load_state_dict(state_dict) return sam