[CP] Add assertion for unsupported load-balance + non-causal (pytorch#141622)

Load-balance mode is not supported for non-causal attention (`is_causal=False`), because of the data shuffling that load-balance mode performs. This PR adds an assertion to make this limitation explicit.

Fixes pytorch#141429

Pull Request resolved: pytorch#141622
Approved by: https://github.com/XilunWu
wconstab authored and pytorchmergebot committed Nov 28, 2024
1 parent b556549 · commit 54d26d6
Showing 1 changed file with 4 additions and 0 deletions.
torch/distributed/tensor/experimental/_attention.py: 4 additions & 0 deletions
@@ -443,6 +443,8 @@ def _templated_ring_attention(
         raise NotImplementedError(
             "is_causal requires the same query and context sequence lengths"
         )
+    if not is_causal and _cp_options.enable_load_balance:
+        raise RuntimeError("Load balancing requires `is_causal=True`.")
 
     if isinstance(mesh, dist.ProcessGroup):
         pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
@@ -616,6 +618,8 @@ def _templated_ring_attention_backward(
     **kwargs: Any,
 ) -> Tuple[torch.Tensor, ...]:
     """This API implements the backward of the ring attention."""
+    if not is_causal and _cp_options.enable_load_balance:
+        raise RuntimeError("Load balancing requires `is_causal=True`.")
     pg = mesh.get_group()
     assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
     rank = dist.get_rank(pg)
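As a usage note (not part of the commit), here is a minimal sketch of what this check means for callers. It uses the internal `_cp_options` object visible in the diff above and assumes `enable_load_balance` defaults to True; internal APIs like this may change without notice.

    # Minimal sketch, not from the PR: illustrates the constraint added above,
    # using the internal _cp_options knob from the changed file (assumed to
    # default to enable_load_balance=True).
    from torch.distributed.tensor.experimental._attention import _cp_options

    # With load balancing enabled, the ring-attention helpers now raise
    # RuntimeError when called with is_causal=False. Callers that need
    # non-causal attention must disable load balancing first:
    _cp_options.enable_load_balance = False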
