Revert "[experimental] async-tp impl with cutlass-based, progress awa…
Browse files Browse the repository at this point in the history
…re kernel (pytorch#139227)"

This reverts commit 5203138.

Reverted pytorch#139227 on behalf of https://github.com/yifuwang due to "Need to address internal build failure D65605027" ([comment](pytorch#139227 (comment)))
pytorchmergebot committed Nov 7, 2024
1 parent d378819 commit 36e0f11
Showing 10 changed files with 6 additions and 961 deletions.
1 change: 0 additions & 1 deletion BUILD.bazel
@@ -723,7 +723,6 @@ cc_library(
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
"torch/csrc/distributed/c10d/NanCheck.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],
1 change: 0 additions & 1 deletion build_variables.bzl
@@ -692,7 +692,6 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
"torch/csrc/distributed/c10d/NanCheck.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
9 changes: 0 additions & 9 deletions caffe2/CMakeLists.txt
@@ -567,15 +567,6 @@ if(USE_CUDA)
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()

set(ASYNC_MM_FILE "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/AsyncMM.cu")
# Disable the warning to make cutlass warp-specialized cooperative kernel build for gcc-9
if(CMAKE_COMPILER_IS_GNUCXX)
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable")
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
endif()
endif()
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
52 changes: 0 additions & 52 deletions test/distributed/test_symmetric_memory.py
@@ -1,7 +1,6 @@
# Owner(s): ["module: c10d"]

import os
from unittest import skipIf

import torch
import torch.distributed as dist
@@ -11,15 +10,13 @@
from torch.distributed._functional_collectives import all_gather_tensor
from torch.distributed._symmetric_memory import (
_fused_all_gather_matmul_fallback,
_fused_all_gather_matmul_native,
_fused_all_gather_scaled_matmul_fallback,
_fused_matmul_reduce_scatter_fallback,
_fused_scaled_matmul_reduce_scatter_fallback,
enable_symm_mem_for_group,
restride_A_for_fused_matmul_reduce_scatter,
restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
@@ -323,55 +320,6 @@ def test_fused_all_gather_matmul(self, gather_dim: int) -> None:

dist.destroy_process_group()

@skipIfRocm
@skipIf(
not SM90OrLater,
"_fused_all_gather_matmul_native currently only supports sm>=90",
)
@skip_if_lt_x_gpu(2)
@parametrize("symm_mem_input", [True, False])
@parametrize("is_b_row_major", [True, False])
def test_fused_all_gather_matmul_native(
self, symm_mem_input: bool, is_b_row_major: bool
) -> None:
self._init_process()

M = 1024
N = 1024
K = 1024
group_name = dist.group.WORLD.group_name

torch.manual_seed(42 + self.rank)
if symm_mem_input:
A_shard = _SymmetricMemory.empty_strided_p2p(
size=(M // self.world_size, K),
stride=(K, 1),
dtype=torch.bfloat16,
device=self.device,
group_name="0",
).normal_()
else:
A_shard = torch.rand(
M // self.world_size, K, dtype=torch.bfloat16, device="cuda"
)

if is_b_row_major:
B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")
else:
B = torch.rand(N, K, dtype=torch.bfloat16, device="cuda").t()

ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
A_shard, [B], gather_dim=0, group_name=group_name
)
ag_target, mm_target = _fused_all_gather_matmul_native(
A_shard, B, group_name=group_name
)

torch.testing.assert_close(ag_target, ag_baseline)
torch.testing.assert_close(mm_target, mm_baseline[0])

dist.destroy_process_group()

@skipIfRocm
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
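For context, the deleted test above checks the native fused path against `_fused_all_gather_matmul_fallback`, whose semantics decompose into a plain all-gather followed by a matmul. A minimal unfused sketch of that baseline, assuming an initialized process group as in the tests (the helper name `all_gather_matmul_reference` is illustrative, not a PyTorch API):

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor

def all_gather_matmul_reference(A_shard: torch.Tensor, B: torch.Tensor):
    # Gather the row-sharded A across ranks along dim 0, then run a regular matmul.
    # This mirrors what the fused all-gather-matmul variants are compared against.
    A = all_gather_tensor(A_shard, gather_dim=0, group=dist.group.WORLD)
    return A, A @ B
```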
1 change: 0 additions & 1 deletion torch/_C/_distributed_c10d.pyi
@@ -668,4 +668,3 @@ class _SymmetricMemory:
def barrier(self, channel: int = 0) -> None: ...
def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
def stream_write_value32(self, addr: int, val: int) -> None: ...
17 changes: 0 additions & 17 deletions torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
@@ -19,7 +19,6 @@

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/cuda/AsyncMM.cuh>

#define INT_SWITCH_CASE(name, val, ...) \
case val: { \
@@ -594,22 +593,6 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
"two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
torch::dispatch(c10::DispatchKey::CUDA, ::two_shot_all_reduce_),
{at::Tag::pt2_compliant_tag});
// An mm that supports consuming asynchronous input. It guarantees the
// following rasterization order, and that the corresponding signal arrives
// before an input chunk is consumed.
//
// num_chunks = a_chunks_signals.numel()
// for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
// chunk_idx = chunk_idx % num_chunks
// wait_signal(a_chunk_signals, chunk_idx)
// # Compute output tiles that consumes the input chunk
m.def(
"_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor",
torch::dispatch(
c10::DispatchKey::CUDA, c10d::cuda::detail::async_input_mm),
{at::Tag::pt2_compliant_tag});
#endif
m.def(
"memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)",
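The comment deleted above spells out the contract of the reverted `_async_input_mm` op: row chunks of `a` are consumed starting at `a_chunk_pivot`, wrapping modulo the number of chunks, and each chunk is only read after its signal in `a_chunk_signals` has arrived. A synchronous Python sketch of just that consumption order (the signal wait is reduced to a comment, and `async_input_mm_reference` is an illustrative name, not part of PyTorch):

```python
import torch

def async_input_mm_reference(a, b, a_chunk_signals, a_chunk_pivot):
    # One signal per row-chunk of `a`; assumes num_chunks divides a.shape[0] evenly.
    num_chunks = a_chunk_signals.numel()
    a_chunks = a.chunk(num_chunks, dim=0)
    out = a.new_empty(a.shape[0], b.shape[1])
    out_chunks = out.chunk(num_chunks, dim=0)
    for i in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
        chunk_idx = i % num_chunks
        # The real kernel would block here: wait_signal(a_chunk_signals, chunk_idx),
        # guaranteeing the producer has filled this chunk before it is consumed.
        torch.mm(a_chunks[chunk_idx], b, out=out_chunks[chunk_idx])
    return out
```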