tensor_parallel_example.py and sequence_parallel_example.py #1353

@githubsgi

Description

The primary difference between the two files is as follows. In the TP case, I see only 1 all-reduce per iteration. Is that what is expected? It seems to be the same as DDP! In the SP case, I see 1 all-gather and 1 reduce-scatter per iteration.

# Custom parallelization plan for the model (sequence parallel): the input
# is Shard(0), so the forward all-gathers before in_proj and
# reduce-scatters after out_proj
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)
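For reference, here is a minimal sketch of how this SP plan would be driven, with a hypothetical per-rank input shape; the Shard(0) layouts above are what turn the forward into the observed all-gather plus reduce-scatter:

import torch

# Each rank feeds only its local shard of dim 0 (shape is hypothetical).
# DTensor all-gathers it to run the column-parallel in_proj, then
# reduce-scatters the row-parallel out_proj output back to Shard(0).
local_inp = torch.rand(4, 10, device="cuda")
out = sp_model(local_inp)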

# Custom parallelization plan for the model (tensor parallel): inputs and
# outputs stay replicated, so the forward needs only one all-reduce, after
# the row-parallel out_proj
tp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
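For context, a minimal sketch of the ToyModel that both plans are applied to, reconstructed from the parameter shapes in the CommDebugMode trace below (the authoritative definition is in the example files):

import torch.nn as nn

# Sketch inferred from the trace: in_proj is a (32, 10) Linear sharded on
# dim 0, out_proj is a (5, 32) Linear sharded on dim 1, with a ReLU between.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))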

In the TP case, CommDebugMode also appears to show 1 all-reduce in the forward pass and none in the backward pass:

  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: <class '__main__.ToyModel'>
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: <class 'torch.nn.modules.linear.Linear'>
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: <class 'torch.nn.modules.activation.ReLU'>
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: <class 'torch.nn.modules.linear.Linear'>
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
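For completeness, a sketch of how a trace like the one above can be captured, assuming a recent PyTorch (on older releases CommDebugMode lives under torch.distributed._tensor.debug instead):

import torch
from torch.distributed.tensor.debug import CommDebugMode

comm_mode = CommDebugMode()
inp = torch.rand(4, 10, device="cuda")  # matches the trace's activation shapes

with comm_mode:
    out = tp_model(inp)
    out.sum().backward()

# Aggregate collective counts across forward and backward
print(comm_mode.get_comm_counts())
# Module-level table like the one quoted above (noise_level controls detail)
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))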
