Add NLLLoss to DTensor prop rule (pytorch#98512)
Pull Request resolved: pytorch#98512
Approved by: https://github.com/wanchaol
mrshenli authored and pytorchmergebot committed Apr 8, 2023
1 parent a6155f3 commit d255c8e
Showing 2 changed files with 90 additions and 13 deletions.
test/distributed/_spmd/test_tracing.py (42 additions, 13 deletions)
@@ -690,32 +690,27 @@ class CoverageTest(DTensorTestBase):
     def world_size(self):
         return 2
 
-    def _test_train_step(self, mod, inp):
-        @compile()
-        def train_step(mod, opt, inp):
-            mod(inp).sum().backward()
-            opt.step()
-
+    def _test_train_step(self, train_step, mod, *args):
         ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank])
 
         opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=True)
         ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=True)
 
-        ddp_inp = deepcopy(inp)
+        ddp_args = deepcopy(args)
 
         # materialize optimizer states
-        mod(inp).sum().backward()
+        mod(*args).sum().backward()
         opt.step()
         opt.zero_grad()
 
-        ddp_mod(ddp_inp).sum().backward()
+        ddp_mod(*ddp_args).sum().backward()
         ddp_opt.step()
         ddp_opt.zero_grad()
 
         # test parameter parity
-        train_step(mod, opt, inp)
+        train_step(mod, opt, *args)
 
-        ddp_mod(ddp_inp).sum().backward()
+        ddp_mod(*ddp_args).sum().backward()
         # FIXME(@mrshenli): DDP by default divides grads by world size, but
         # torch.distributed.compile does not do that yet.
         with torch.no_grad():
@@ -730,12 +725,46 @@ def train_step(mod, opt, inp):
     @with_comms
     def test_log_softmax(self):
         torch.manual_seed(0)
+
+        @compile()
+        def train_step(mod, opt, inp):
+            mod(inp).sum().backward()
+            opt.step()
+
         mod = nn.Sequential(
             nn.Linear(10, 10),
             nn.LogSoftmax(dim=1),
         ).cuda(self.rank)
-        inp = torch.randn(20, 10).cuda(self.rank)
-        self._test_train_step(mod, inp)
+        inp = torch.randn(2, 10).cuda(self.rank)
+        self._test_train_step(train_step, mod, inp)
+
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_nll_loss(self):
+        class ModuleWithLoss(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod = nn.Sequential(
+                    nn.Linear(10, 10),
+                    nn.LogSoftmax(dim=1),
+                )
+                self.lss = nn.NLLLoss()
+
+            def forward(self, x, tgt):
+                return self.lss(self.mod(x), tgt)
+
+        torch.manual_seed(0)
+        mod = ModuleWithLoss().cuda(self.rank)
+
+        @compile()
+        def train_step(mod, opt, inp, tgt):
+            mod(inp, tgt).backward()
+            opt.step()
+
+        inp = torch.randn(2, 10).to(self.rank)
+        tgt = torch.empty(2, dtype=torch.long).random_(0, 10).to(self.rank)
+
+        self._test_train_step(train_step, mod, inp, tgt)
 
 
 if __name__ == "__main__":
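
The propagation rules added below in torch/distributed/_spmd/experimental_ops.py key off aten.nll_loss_forward, which returns two tensors rather than one: the (possibly reduced) loss and total_weight. A minimal sketch, not part of this commit and assuming a recent PyTorch build, of calling the op directly to inspect both outputs:

# Hypothetical illustration of the operator covered by the new prop rule; the argument
# order follows the aten schema (self, target, weight, reduction, ignore_index), with
# reduction codes 0=none, 1=mean, 2=sum.
import torch

log_probs = torch.randn(4, 10).log_softmax(dim=1)
target = torch.randint(0, 10, (4,))

output, total_weight = torch.ops.aten.nll_loss_forward(
    log_probs, target, None, 1, -100  # weight=None, reduction=mean, ignore_index=-100
)
print(output.shape, total_weight.shape)  # both are 0-dim (scalar) tensors
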
torch/distributed/_spmd/experimental_ops.py (48 additions, 0 deletions)
@@ -170,6 +170,54 @@ def _prop__fused_adam(op_schema: OpSchema):
     return OutputSharding(output_spec=(op_schema.args_schema[0],) * NT)  # type: ignore[arg-type]
 
 
+@register_prop_rule(aten.nll_loss_forward.default)  # pyre-ignore
+def _prop_nll_loss_forward(op_schema: OpSchema) -> OutputSharding:
+    self, target = op_schema.args_schema[:2]
+    assert isinstance(self, DTensorSpec)
+    assert isinstance(target, DTensorSpec)
+    if self.placements != target.placements:
+        # Self and target must match in placements, which should be sharded along
+        # the batch dimension in data parallel use cases. Force redistribute.
+
+        # Need to create a new self spec instead of returning (target, target), as
+        # target and self might not match in shape.
+        new_self = DTensorSpec(
+            mesh=self.mesh,
+            placements=target.placements,
+            tensor_meta=self.tensor_meta,
+        )
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    func_schema=op_schema.func_schema,
+                    args_schema=(new_self, target) + op_schema.args_schema[2:],
+                    kwargs_schema=op_schema.kwargs_schema,
+                    is_inplace=op_schema.is_inplace,
+                    is_out_variant=op_schema.is_out_variant,
+                )
+            ],
+        )
+    else:
+        return OutputSharding(
+            output_spec=(
+                # By default, nll_loss_forward conducts a reduction and returns
+                # a scalar tensor, hence the _Partial placement.
+                DTensorSpec(mesh=self.mesh, placements=[_Partial()]),
+                # The 2nd output total_weight is always a scalar tensor.
+                DTensorSpec(mesh=self.mesh, placements=[Replicate()]),
+            )
+        )
+
+
+@register_prop_rule(aten.nll_loss_backward.default)  # pyre-ignore
+def _prop_nll_loss_backward(op_schema: OpSchema) -> OutputSharding:
+    grad_output, self = op_schema.args_schema[:2]
+    assert isinstance(grad_output, DTensorSpec)
+    assert isinstance(self, DTensorSpec)
+    return OutputSharding(output_spec=self)
+
+
 @register_prop_rule(aten.native_layer_norm.default)  # pyre-ignore
 def _prop_native_layer_norm(op_schema: OpSchema) -> OutputSharding:
     input, normalized_shape, weight, bias, eps = op_schema.args_schema
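
The comments in _prop_nll_loss_forward explain the placement choice: with the batch sharded across ranks, each rank's reduced loss is only a partial value (hence _Partial, pending a cross-rank reduction), while total_weight is replicated. The backward rule keeps the input's sharding because each rank's gradient rows depend only on that rank's samples. A minimal single-process sketch of that intuition, not part of this commit; it uses reduction="sum" so the additive structure is explicit:

# Simulate a 2-way batch shard on one process and check that (a) per-shard losses sum
# to the full-batch loss and (b) the full-batch gradient is the per-shard gradients
# concatenated along the batch dim.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
target = torch.randint(0, 10, (4,))

full = F.nll_loss(logits.log_softmax(dim=1), target, reduction="sum")
(full_grad,) = torch.autograd.grad(full, logits)

shard_losses, shard_grads = [], []
for lo in (0, 2):  # rows 0-1 and 2-3 play the role of the two ranks' Shard(0) pieces
    shard = logits[lo : lo + 2].detach().requires_grad_(True)
    loss = F.nll_loss(shard.log_softmax(dim=1), target[lo : lo + 2], reduction="sum")
    shard_losses.append(loss)
    shard_grads.append(torch.autograd.grad(loss, shard)[0])

# Loss: partial per-rank sums reduce to the full loss (the _Partial semantics).
torch.testing.assert_close(full, sum(shard_losses))
# Grad: matches the input's batch sharding (output_spec=self in the backward rule).
torch.testing.assert_close(full_grad, torch.cat(shard_grads))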
