The primary difference between the two files is as follows. In the TP case I only see 1 all-reduce per iteration. Is that expected? It seems to be the same as DDP! In the SP case I see 1 all-gather and 1 reduce-scatter per iteration.
# Custom parallelization plan for the model (sequence parallel): inputs/outputs sharded on dim 0
sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

# Custom parallelization plan for the model (tensor parallel): replicated inputs/outputs
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
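
For reference, here is a minimal sketch of the setup both plans assume. The ToyModel shapes are read off the CommDebugMode trace below; the import paths and the exact script are assumptions and may differ across PyTorch versions.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)   # weight ends up Shard(0) under the TP plan
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)   # weight ends up Shard(1) under the TP plan

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))

# 1-D mesh over the 4 GPUs that appear in the trace below
device_mesh = init_device_mesh("cuda", (4,))
model = ToyModel().cuda()
tp_model = ToyModel().cuda()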
CommDebugMode also appears to show 1 all-reduce in the forward pass and no all-reduce in the backward pass.
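
The table below was generated roughly like this (a sketch; on older releases CommDebugMode lives under torch.distributed._tensor.debug, and noise_level controls how much operator detail is printed):

from torch.distributed.tensor.debug import CommDebugMode

inp = torch.rand(4, 10, device="cuda")   # batch of 4, matching the shapes in the trace
comm_mode = CommDebugMode()
with comm_mode:
    tp_model(inp).sum().backward()

# Per-module breakdown of collectives, like the output shown below
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))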
FORWARD PASS
  *c10d_functional.all_reduce: 1
BACKWARD PASS
ToyModel
  *module type: <class '__main__.ToyModel'>
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  ToyModel.in_proj
    *module type: <class 'torch.nn.modules.linear.Linear'>
    *Parameter List
      *weight: (Shard(dim=0),)
      *bias: (Shard(dim=0),)
    FORWARD PASS
      **aten.addmm.default
        shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
        sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
    BACKWARD PASS
      **aten.mm.default
        shape: [torch.Size([32, 4]), torch.Size([4, 10])]
        sharding: [(Shard(dim=0),), (Replicate(),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.sum.dim_IntList
        shape: [torch.Size([4, 32])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.add_.Tensor
        shape: [torch.Size([32]), torch.Size([32])]
        sharding: [(Shard(dim=0),), (Shard(dim=0),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.add_.Tensor
        shape: [torch.Size([32, 10]), torch.Size([32, 10])]
        sharding: [(Shard(dim=0),), (Shard(dim=0),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
  ToyModel.relu
    *module type: <class 'torch.nn.modules.activation.ReLU'>
    FORWARD PASS
    BACKWARD PASS
  ToyModel.out_proj
    *module type: <class 'torch.nn.modules.linear.Linear'>
    *Parameter List
      *weight: (Shard(dim=1),)
      *bias: (Replicate(),)
    FORWARD PASS
      *c10d_functional.all_reduce: 1
      **aten.addmm.default
        shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
        sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
    BACKWARD PASS
      **aten.mm.default
        shape: [torch.Size([4, 5]), torch.Size([5, 32])]
        sharding: [(Replicate(),), (Shard(dim=1),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.mm.default
        shape: [torch.Size([5, 4]), torch.Size([4, 32])]
        sharding: [(Replicate(),), (Shard(dim=1),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.sum.dim_IntList
        shape: [torch.Size([4, 5])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.add_.Tensor
        shape: [torch.Size([5]), torch.Size([5])]
        sharding: [(Replicate(),), (Replicate(),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
      **aten.add_.Tensor
        shape: [torch.Size([5, 32]), torch.Size([5, 32])]
        sharding: [(Shard(dim=1),), (Shard(dim=1),)]
        device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
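
To double-check the per-iteration totals quoted at the top, the aggregate counts can also be read programmatically. A minimal sketch, assuming the setup and models from the snippets above:

from torch.distributed.tensor.debug import CommDebugMode

def count_comms(parallel_model):
    mode = CommDebugMode()
    # Under the SP plan the passed-in tensor is treated as the local Shard(0) piece on each rank.
    inp = torch.rand(4, 10, device="cuda")
    with mode:
        parallel_model(inp).sum().backward()
    return mode.get_comm_counts()   # maps each functional collective to its count

print(count_comms(tp_model))   # reported above: 1 all_reduce per iteration
print(count_comms(sp_model))   # reported above: 1 all_gather and 1 reduce_scatter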