Contigous on gather (huggingface#1771)
* For testing

* Contigous
muellerzr authored Jul 25, 2023
1 parent 6e70e79 commit c3d82d2
Showing 2 changed files with 17 additions and 0 deletions.
src/accelerate/test_utils/scripts/test_ops.py (10 additions, 0 deletions)
@@ -46,6 +46,14 @@ def test_gather_object(state):
     assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
 
 
+def test_gather_non_contigous(state):
+    # Create a non-contiguous tensor
+    tensor = torch.arange(12).view(4, 3).t().to(state.device)
+    assert not tensor.is_contiguous()
+    # Shouldn't error out
+    _ = gather(tensor)
+
+
 def test_broadcast(state):
     tensor = create_tensor(state)
     broadcasted_tensor = broadcast(tensor)
@@ -133,6 +141,8 @@ def main():
     test_gather(state)
     state.print("testing gather_object")
     test_gather_object(state)
+    state.print("testing gather non-contigous")
+    test_gather_non_contigous(state)
     state.print("testing broadcast")
     test_broadcast(state)
     state.print("testing pad_across_processes")
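The new test builds its non-contiguous input by transposing a view, which reorders strides without copying any data. A minimal standalone sketch of that behavior (illustration only, not part of the diff):

import torch

tensor = torch.arange(12).view(4, 3).t()  # shape (3, 4), strides (1, 3)
print(tensor.is_contiguous())             # False: storage is not laid out row-major for this shape
print(tensor.stride())                    # (1, 3)

fixed = tensor.contiguous()               # materializes a row-major copy
print(fixed.is_contiguous())              # True
print(torch.equal(tensor, fixed))         # True: same values, different memory layout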
src/accelerate/utils/operations.py (7 additions, 0 deletions)
@@ -269,6 +269,9 @@ def _tpu_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
 
+        # Can only gather contiguous tensors
+        if not tensor.is_contiguous():
+            tensor = tensor.contiguous()
         return xm.all_gather(tensor)
 
     res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True)
@@ -280,6 +283,10 @@ def _gpu_gather(tensor):
     def _gpu_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
+
+        # Can only gather contiguous tensors
+        if not tensor.is_contiguous():
+            tensor = tensor.contiguous()
         output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
         torch.distributed.all_gather(output_tensors, tensor)
         return torch.cat(output_tensors, dim=0)
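The collectives behind these helpers expect contiguous inputs (hence the test's "Shouldn't error out" comment), so both the TPU and GPU gather paths now copy non-contiguous tensors with .contiguous() before gathering. A rough usage sketch, assuming a script launched across several processes with `accelerate launch`; the import paths and setup below are assumptions for illustration, not part of the commit:

import torch
from accelerate import Accelerator
from accelerate.utils import gather

accelerator = Accelerator()

# A transposed view is non-contiguous; with this change, gather() handles it internally.
tensor = torch.arange(12, device=accelerator.device).view(4, 3).t()
gathered = gather(tensor)  # concatenated along dim 0: shape (3 * num_processes, 4)
print(gathered.shape)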
