Skip to content

Commit

Permalink
remove .pin_memory() in obj_pos of SAM2Base to resolve and erro…
Browse files Browse the repository at this point in the history
…r in MPS (facebookresearch#495)

In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in facebookresearch#487.

(close facebookresearch#487)
  • Loading branch information
ronghanghu authored Dec 16, 2024
1 parent 722d1d1 commit 2b90b9f
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions sam2/modeling/sam2_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,10 +628,8 @@ def _prepare_memory_conditioned_features(
if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = (
torch.tensor(pos_list)
.pin_memory()
.to(device=device, non_blocking=True)
obj_pos = torch.tensor(pos_list).to(
device=device, non_blocking=True
)
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
Expand Down

0 comments on commit 2b90b9f

Please sign in to comment.