From 28a57b061073127dbac82e1be4d834a24141fbbc Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 19 Nov 2024 22:43:28 -0800 Subject: [PATCH] NVLS support for msccl++ executor (#375) - Support more data types for multicast operations - Add new OP MULTI_LOAD_REDUCE_STORE to support NVLS - Modify allocSharedPhysicalCuda, which now returns std::shared_ptr<T> instead of std::shared_ptr<PhysicalCudaMemory<T>> - Add Python support for allocSharedPhysicalCuda Test passed for `allreduce_nvls.json` --- apps/nccl/src/allreduce.hpp | 2 +- docker/build.sh | 2 +- docs/getting-started/quickstart.md | 4 + include/mscclpp/gpu.hpp | 9 +- .../mscclpp}/gpu_data_types.hpp | 1 + include/mscclpp/gpu_utils.hpp | 155 +- include/mscclpp/nvls.hpp | 20 +- include/mscclpp/nvls_device.hpp | 27 +- python/mscclpp/__init__.py | 1 + python/mscclpp/core_py.cpp | 2 + python/mscclpp/gpu_utils_py.cpp | 30 + python/mscclpp/nvls_py.cpp | 2 +- python/mscclpp_benchmark/mscclpp_op.py | 11 +- python/test/executor_test.py | 18 +- src/executor/execution_plan.cc | 140 +- src/executor/executor.cc | 62 +- src/include/execution_common.hpp | 14 +- src/include/execution_kernel.hpp | 71 +- src/include/execution_plan.hpp | 17 + src/include/registered_memory.hpp | 5 + src/nvls.cc | 54 +- src/registered_memory.cc | 149 +- src/utils.cc | 19 +- test/execution-files/allreduce_nvls.json | 1458 +++++++++++++++++ test/executor_test.cc | 4 + test/nvls_test.cu | 51 +- 26 files changed, 2116 insertions(+), 212 deletions(-) rename {src/include => include/mscclpp}/gpu_data_types.hpp (95%) create mode 100644 python/mscclpp/gpu_utils_py.cpp create mode 100644 test/execution-files/allreduce_nvls.json diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index c4c1b1a5e..1cd7f3033 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -7,12 +7,12 @@ #include #include #include +#include <mscclpp/gpu_data_types.hpp> #include #include #include #include "common.hpp" -#include "gpu_data_types.hpp" template __forceinline__ __device__ To bit_cast(const From& src) { diff --git a/docker/build.sh b/docker/build.sh index 3e2169f68..af4a23025 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -9,7 +9,7 @@ baseImageTable=( ["cuda12.2"]="nvidia/cuda:12.2.2-devel-ubuntu20.04" ["cuda12.3"]="nvidia/cuda:12.3.2-devel-ubuntu20.04" ["cuda12.4"]="nvidia/cuda:12.4.1-devel-ubuntu22.04" - ["rocm6.2"]="rocm/rocm-terminal:6.2" + ["rocm6.2"]="rocm/rocm-terminal:6.2.1" ) declare -A extraLdPathTable diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 8c0982e3e..9eff7e0ec 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -19,6 +19,10 @@ ``` lsmod | grep nvidia_peermem ``` + * For GPUs with NVLS support, the IMEX channels should be set up (refer to [cuMemCreate](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g899d69a862bba36449789c64b430dc7c)).
You can set up the channels manually via: + ``` + sudo nvidia-modprobe -s -i + ``` ## Build with Docker Images diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index f291f610c..914f32e8e 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -22,6 +22,7 @@ using CUdeviceptr = hipDeviceptr_t; using CUmemGenericAllocationHandle = hipMemGenericAllocationHandle_t; using CUmemAllocationProp = hipMemAllocationProp; using CUmemAccessDesc = hipMemAccessDesc; +using CUmemAllocationHandleType = hipMemAllocationHandleType; constexpr auto cudaSuccess = hipSuccess; constexpr auto cudaStreamNonBlocking = hipStreamNonBlocking; @@ -86,6 +87,9 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #define cuMemSetAccess(...) hipMemSetAccess(__VA_ARGS__) #define cuMemMap(...) hipMemMap(__VA_ARGS__) #define cuMemUnmap(...) hipMemUnmap(__VA_ARGS__) +#define cuMemRetainAllocationHandle(...) hipMemRetainAllocationHandle(__VA_ARGS__) +#define cuMemExportToShareableHandle(...) hipMemExportToShareableHandle(__VA_ARGS__) +#define cuMemImportFromShareableHandle(...) hipMemImportFromShareableHandle(__VA_ARGS__) #else @@ -97,9 +101,10 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri // NVLS #if !defined(__HIP_PLATFORM_AMD__) #include -#define USE_NVLS ((CUDART_VERSION >= 12010) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) +// We need CU_MEM_HANDLE_TYPE_FABRIC (instroduced in cuda12.3) to support sharing handles across GPUs via sockets +#define CUDA_NVLS_SUPPORTED ((CUDART_VERSION >= 12030) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) #else // !defined(__HIP_PLATFORM_AMD__) -#define USE_NVLS 0 +#define CUDA_NVLS_SUPPORTED 0 #endif // !defined(__HIP_PLATFORM_AMD__) // GPU sync threads diff --git a/src/include/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp similarity index 95% rename from src/include/gpu_data_types.hpp rename to include/mscclpp/gpu_data_types.hpp index 8d2a6fc79..2ec480760 100644 --- a/src/include/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -16,6 +16,7 @@ using __bfloat162 = __hip_bfloat162; #else #include +#include #if (CUDART_VERSION >= 11000) #include #endif diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index 9be6a7d16..014b9d390 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -9,6 +9,7 @@ #include "errors.hpp" #include "gpu.hpp" +#include "utils.hpp" /// Throw @ref mscclpp::CudaError if @p cmd does not return cudaSuccess. /// @param cmd The command to execute. @@ -34,6 +35,19 @@ namespace mscclpp { +/// set memory access permission to read-write +/// @param base Base memory pointer. +/// @param size Size of the memory. +inline void setReadWriteMemoryAccess(void* base, size_t size) { + CUmemAccessDesc accessDesc = {}; + int deviceId; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = deviceId; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, size, &accessDesc, 1)); +} + /// A RAII guard that will cudaThreadExchangeStreamCaptureMode to cudaStreamCaptureModeRelaxed on construction and /// restore the previous mode on destruction. This is helpful when we want to avoid CUDA graph capture. 
struct AvoidCudaGraphCaptureGuard { @@ -53,15 +67,6 @@ struct CudaStreamWithFlags { template struct CudaDeleter; -template -struct PhysicalCudaMemory { - CUmemGenericAllocationHandle memHandle_; - T* devicePtr_; - size_t size_; - PhysicalCudaMemory(CUmemGenericAllocationHandle memHandle, T* devicePtr, size_t size) - : memHandle_(memHandle), devicePtr_(devicePtr), size_(size) {} -}; - namespace detail { /// A wrapper of cudaMalloc that sets the allocated memory to zero. @@ -79,46 +84,38 @@ T* cudaCalloc(size_t nelem) { return ptr; } +#if (CUDA_NVLS_SUPPORTED) template -PhysicalCudaMemory* cudaPhysicalCalloc(size_t nelem, size_t gran) { +T* cudaPhysicalCalloc(size_t nelems, size_t gran) { AvoidCudaGraphCaptureGuard cgcGuard; - int deviceId = -1; + CUdevice currentDevice; MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); + MSCCLPP_CUTHROW(cuDeviceGet(¤tDevice, deviceId)); CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = deviceId; -#if defined(__HIP_PLATFORM_AMD__) - // TODO: revisit when HIP fixes this typo in the field name - prop.requestedHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; -#else - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; -#endif + prop.requestedHandleTypes = + (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); + prop.location.id = currentDevice; - CUmemGenericAllocationHandle memHandle; - size_t bufferSize = sizeof(T) * nelem; // allocate physical memory - MSCCLPP_CUTHROW(cuMemCreate(&memHandle, bufferSize, &prop, 0 /*flags*/)); - - CUmemAccessDesc accessDesc = {}; - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + CUmemGenericAllocationHandle memHandle; + size_t nbytes = (nelems * sizeof(T) + gran - 1) / gran * gran; + MSCCLPP_CUTHROW(cuMemCreate(&memHandle, nbytes, &prop, 0 /*flags*/)); T* devicePtr = nullptr; - // Map the device pointer - MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, bufferSize, gran, 0U, 0)); - MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, bufferSize, 0, memHandle, 0)); - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)devicePtr, bufferSize, &accessDesc, 1)); + MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, nbytes, gran, 0U, 0)); + MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, nbytes, 0, memHandle, 0)); + setReadWriteMemoryAccess(devicePtr, nbytes); CudaStreamWithFlags stream(cudaStreamNonBlocking); - MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, bufferSize, stream)); - + MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, nbytes, stream)); MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); - return new PhysicalCudaMemory(memHandle, devicePtr, bufferSize); + return devicePtr; } +#endif template T* cudaExtCalloc(size_t nelem) { @@ -206,11 +203,15 @@ struct CudaDeleter { template struct CudaPhysicalDeleter { static_assert(!std::is_array_v, "T must not be an array"); - void operator()(PhysicalCudaMemory* ptr) { + void operator()(T* ptr) { AvoidCudaGraphCaptureGuard cgcGuard; - MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr->devicePtr_, ptr->size_)); - MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr->devicePtr_, ptr->size_)); - MSCCLPP_CUTHROW(cuMemRelease(ptr->memHandle_)); + CUmemGenericAllocationHandle handle; + size_t size = 0; + MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr)); + MSCCLPP_CUTHROW(cuMemGetAddressRange(NULL, &size, 
(CUdeviceptr)ptr)); + MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, size)); + MSCCLPP_CUTHROW(cuMemRelease(handle)); + MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, size)); } }; @@ -234,16 +235,46 @@ std::shared_ptr allocSharedCuda(size_t count = 1) { return detail::safeAlloc, CudaDeleter, std::shared_ptr>(count); } -/// Allocated physical memory on the device and returns a memory handle along with a memory handle for it. -/// The deallocation only happens PhysicalCudaMemory goes out of scope. +#if (CUDA_NVLS_SUPPORTED) +static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag) { + size_t gran = 0; + int numDevices = 0; + MSCCLPP_CUDATHROW(cudaGetDeviceCount(&numDevices)); + + CUmulticastObjectProp prop = {}; + prop.size = size; + // This is a dummy value, it might affect the granularity in the future + prop.numDevices = numDevices; + prop.handleTypes = (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); + prop.flags = 0; + MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, granFlag)); + return gran; +} +#endif + +/// Allocates physical memory on the device and returns a std::shared_ptr to it. The memory is zeroed out. /// @tparam T Type of each element in the allocated memory. /// @param count Number of elements to allocate. /// @param gran the granularity of the allocation. -/// @return A std::shared_ptr to the memory handle and a device pointer for that memory. +/// @return A std::shared_ptr to the allocated memory. template -std::shared_ptr> allocSharedPhysicalCuda(size_t count, size_t gran) { - return detail::safeAlloc, detail::cudaPhysicalCalloc, CudaPhysicalDeleter, - std::shared_ptr>>(count, gran); +std::shared_ptr allocSharedPhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) { +#if (CUDA_NVLS_SUPPORTED) + if (!isNvlsSupported()) { + throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage); + } + if (count == 0) { + return nullptr; + } + + if (gran == 0) { + gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); + } + size_t nelems = ((count * sizeof(T) + gran - 1) / gran * gran) / sizeof(T); + return detail::safeAlloc, CudaPhysicalDeleter, std::shared_ptr>(nelems, gran); +#else + throw Error("Only support GPU with Fabric support", ErrorCode::InvalidUsage); +#endif } /// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out. @@ -269,18 +300,6 @@ UniqueCudaPtr allocUniqueCuda(size_t count = 1) { return detail::safeAlloc, CudaDeleter, UniqueCudaPtr>(count); } -/// Allocated physical memory on the device and returns a memory handle along with a virtual memory handle for it. -/// The memory is zeroed out. -/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @param gran the granularity of the allocation. -/// @return A std::unique_ptr to the memory handle and a device pointer for that memory. -template -std::unique_ptr> allocUniquePhysicalCuda(size_t count, size_t gran) { - return detail::safeAlloc, detail::cudaPhysicalCalloc, CudaPhysicalDeleter, - std::unique_ptr, CudaDeleter>>>(count, gran); -} - /// Allocates memory on the device and returns a std::unique_ptr to it. The memory is zeroed out. /// @tparam T Type of each element in the allocated memory. /// @param count Number of elements to allocate. 
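For reference, here is a minimal usage sketch of the reworked allocator, assuming an NVLS-capable GPU; the wrapper function name is hypothetical and not part of this patch:
```
// Illustrative sketch only: allocate an NVLS-bindable, physically-backed buffer.
// Passing gran = 0 rounds the size up to the recommended multicast granularity.
#include <mscclpp/gpu_utils.hpp>

#include <cstddef>
#include <memory>

std::shared_ptr<float> allocateNvlsBuffer(size_t nelems) {
  // Throws ErrorCode::InvalidUsage on GPUs without NVLS/Fabric support.
  return mscclpp::allocSharedPhysicalCuda<float>(nelems);
}
```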
@@ -349,6 +368,32 @@ UniqueCudaHostPtr makeUniqueCudaHost(size_t count) { return ptr; } +/// Allocates physical memory on the device and returns a std::unique_ptr to it. +/// The memory is zeroed out. +/// @tparam T Type of each element in the allocated memory. +/// @param count Number of elements to allocate. +/// @param gran the granularity of the allocation. +/// @return A std::unique_ptr to the allocated memory. +template +std::unique_ptr allocUniquePhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) { +#if (CUDA_NVLS_SUPPORTED) + if (!isNvlsSupported()) { + throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage); + } + if (count == 0) { + return nullptr; + } + + if (gran == 0) { + gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); + } + return detail::safeAlloc, CudaPhysicalDeleter, + std::unique_ptr, CudaDeleter>>>(count, gran); +#else + throw Error("Only support GPU with Fabric support", ErrorCode::InvalidUsage); +#endif +} + /// Asynchronous cudaMemcpy without capture into a CUDA graph. /// @tparam T Type of each element in the allocated memory. /// @param dst Destination pointer. diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp index 4acc040e8..36ad614ba 100644 --- a/include/mscclpp/nvls.hpp +++ b/include/mscclpp/nvls.hpp @@ -25,26 +25,26 @@ class NvlsConnection { struct DeviceMulticastPointer { private: - std::shared_ptr> deviceMem_; + void* devicePtr_; std::shared_ptr mcPtr_; size_t bufferSize_; public: using DeviceHandle = DeviceMulticastPointerDeviceHandle; - DeviceMulticastPointer(std::shared_ptr> deviceMem, std::shared_ptr mcPtr, - size_t bufferSize) - : deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {} + DeviceMulticastPointer(void* devicePtr, std::shared_ptr mcPtr, size_t bufferSize) + : devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {} DeviceHandle deviceHandle(); - char* getDevicePtr(); + void* getDevicePtr(); friend class NvlsConnection; }; - std::shared_ptr allocateAndBindCuda(size_t size); - - /// The \p handle to the allocation (its lifetime is managed by the caller) - /// and the \p size of the allocation. - std::shared_ptr bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size); + /// @brief Bind the memory allocated via @ref mscclpp::allocSharedPhysicalCuda to the multicast handle. The behavior + /// is undefined if the devicePtr is not allocated by @ref mscclpp::allocSharedPhysicalCuda.
+ /// @param devicePtr + /// @param size + /// @return DeviceMulticastPointer with devicePtr, mcPtr and bufferSize + DeviceMulticastPointer bindAllocatedMemory(CUdeviceptr devicePtr, size_t size); size_t getMultiCastMinGranularity(); diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp index 402a65218..622a1a597 100644 --- a/include/mscclpp/nvls_device.hpp +++ b/include/mscclpp/nvls_device.hpp @@ -11,6 +11,8 @@ #include #endif // defined(MSCCLPP_DEVICE_CUDA) +#include + #include "device.hpp" namespace mscclpp { @@ -27,7 +29,11 @@ struct DeviceMulticastPointerDeviceHandle { #if defined(MSCCLPP_DEVICE_CUDA) template MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) { - if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (std::is_same_v && std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.s32 %0, [%1];" : "=r"(val) : "l"(ptr) : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.u32 %0, [%1];" : "=r"(val) : "l"(ptr) : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];" : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) : "l"(ptr) @@ -51,6 +57,13 @@ struct DeviceMulticastPointerDeviceHandle { : "memory"); } else if constexpr (std::is_same_v && std::is_same_v) { asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(ptr) + : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.bf16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory"); } else { static_assert(dependentFalse, "Not supported type"); } @@ -58,7 +71,11 @@ struct DeviceMulticastPointerDeviceHandle { template MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) { - if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile("multimem.st.relaxed.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile("multimem.st.relaxed.sys.global.u32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) : "memory"); @@ -76,6 +93,12 @@ struct DeviceMulticastPointerDeviceHandle { : "memory"); } else if constexpr (std::is_same_v && std::is_same_v) { asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile("multimem.st.relaxed.sys.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y), + "r"(val.z), "r"(val.w) + : "memory"); + } else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile("multimem.st.relaxed.sys.global.bf16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory"); } else { static_assert(dependentFalse, "Not supported type"); } diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index ce9917ab5..1b79fc130 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -26,6 +26,7 @@ 
PacketType, version, is_nvls_supported, + alloc_shared_physical_cuda, npkit, ) diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 257de400a..95048f18b 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -23,6 +23,7 @@ extern void register_numa(nb::module_& m); extern void register_nvls(nb::module_& m); extern void register_executor(nb::module_& m); extern void register_npkit(nb::module_& m); +extern void register_gpu_utils(nb::module_& m); template void def_nonblocking_future(nb::handle& m, const std::string& typestr) { @@ -194,4 +195,5 @@ NB_MODULE(_mscclpp, m) { register_nvls(m); register_executor(m); register_npkit(m); + register_gpu_utils(m); } diff --git a/python/mscclpp/gpu_utils_py.cpp b/python/mscclpp/gpu_utils_py.cpp new file mode 100644 index 000000000..32c578fb7 --- /dev/null +++ b/python/mscclpp/gpu_utils_py.cpp @@ -0,0 +1,30 @@ +#include +#include + +// #include +#include +#include + +namespace nb = nanobind; +using namespace mscclpp; + +class PyCudaMemory { + public: + PyCudaMemory(size_t size) : size_(size) { ptr_ = allocSharedPhysicalCuda(size); } + + uintptr_t getPtr() const { return (uintptr_t)(ptr_.get()); } + size_t size() const { return size_; } + + private: + std::shared_ptr ptr_; + size_t size_; +}; + +void register_gpu_utils(nb::module_& m) { + nb::class_(m, "PyCudaMemory") + .def(nb::init(), nb::arg("size")) + .def("get_ptr", &PyCudaMemory::getPtr, "Get the raw pointer") + .def("size", &PyCudaMemory::size, "Get the size of the allocated memory"); + m.def( + "alloc_shared_physical_cuda", [](size_t size) { return std::make_shared(size); }, nb::arg("size")); +} diff --git a/python/mscclpp/nvls_py.cpp b/python/mscclpp/nvls_py.cpp index 819a7c6b0..91b966bd8 100644 --- a/python/mscclpp/nvls_py.cpp +++ b/python/mscclpp/nvls_py.cpp @@ -30,7 +30,7 @@ void register_nvls(nb::module_& m) { }); nb::class_(m, "NvlsConnection") - .def("allocate_bind_memory", &NvlsConnection::allocateAndBindCuda) + .def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("devicePtr"), nb::arg("size")) .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity); m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"), diff --git a/python/mscclpp_benchmark/mscclpp_op.py b/python/mscclpp_benchmark/mscclpp_op.py index 706107bef..88840a743 100644 --- a/python/mscclpp_benchmark/mscclpp_op.py +++ b/python/mscclpp_benchmark/mscclpp_op.py @@ -1,7 +1,7 @@ import os import cupy as cp import ctypes -from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore +from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore, alloc_shared_physical_cuda import mscclpp.comm as mscclpp_comm from mscclpp.utils import KernelBuilder, pack @@ -443,12 +443,15 @@ def __init__( self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls) min_gran = self.nvls_connection.get_multicast_min_granularity() aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran) - self.nvls_mem_handle = self.nvls_connection.allocate_bind_memory( - aligned_buffer_size + buffer_raw = alloc_shared_physical_cuda(aligned_buffer_size) + self.nvls_mem_handle = self.nvls_connection.bind_allocated_memory( + buffer_raw.get_ptr(), aligned_buffer_size ) # just using recommended size for now self.memory_ptr = self.nvls_mem_handle.get_device_ptr() - self.cp_memory_ptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, None), 0) + 
self.cp_memory_ptr = cp.cuda.MemoryPointer( + cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, buffer_raw), 0 + ) self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr) # create a sm_channel for each remote neighbor diff --git a/python/test/executor_test.py b/python/test/executor_test.py index d0cda18a3..60cf36b95 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -8,6 +8,8 @@ ExecutionPlan, PacketType, npkit, + alloc_shared_physical_cuda, + is_nvls_supported, ) import mscclpp.comm as mscclpp_comm from mscclpp.utils import KernelBuilder, pack @@ -125,6 +127,18 @@ def dtype_to_mscclpp_dtype(dtype): raise ValueError(f"Unknown data type: {dtype}") +def allocate_buffer(nelems, dtype): + if is_nvls_supported: + buffer_raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize) + buffer_ptr = cp.cuda.MemoryPointer( + cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0 + ) + buffer = cp.ndarray(nelems, dtype=dtype, memptr=buffer_ptr) + return buffer + else: + return cp.zeros(nelems, dtype=dtype) + + def build_bufs( execution_plan_name: str, size: int, @@ -144,14 +158,14 @@ def build_bufs( nelems_input = nelems nelems_output = nelems - result_buf = cp.zeros(nelems_output, dtype=dtype) + result_buf = allocate_buffer(nelems_output, dtype=dtype) if in_place: if "allgather" in execution_plan_name: input_buf = cp.split(result_buf, num_ranks)[rank] else: input_buf = result_buf else: - input_buf = cp.zeros(nelems_input, dtype=dtype) + input_buf = allocate_buffer(nelems_input, dtype=dtype) test_buf = cp.zeros(nelems_output, dtype=dtype) return input_buf, result_buf, test_buf diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc index 827623eab..49ceddf0a 100644 --- a/src/executor/execution_plan.cc +++ b/src/executor/execution_plan.cc @@ -52,6 +52,8 @@ auto getOpType = [](const std::string& str) { return mscclpp::OperationType::TRANSFORM_TO_PACKET; } else if (str == "rpkt") { return mscclpp::OperationType::REDUCE_PACKET; + } else if (str == "glres") { + return mscclpp::OperationType::MULTI_LOAD_REDUCE_STORE; } else { throw mscclpp::Error("Invalid operation type", mscclpp::ErrorCode::ExecutorError); } @@ -76,11 +78,15 @@ auto convertToChannelType = [](const std::string& str) { return mscclpp::ChannelType::PROXY; } else if (str == "none") { return mscclpp::ChannelType::NONE; + } else if (str == "nvls") { + return mscclpp::ChannelType::NVLS; } else { throw mscclpp::Error("Invalid channel type", mscclpp::ErrorCode::ExecutorError); } }; +std::set groupChannelType{mscclpp::ChannelType::NVLS}; + } // namespace namespace mscclpp { @@ -100,7 +106,7 @@ std::vector ExecutionPlan::Impl::getChannelInfos(int rank, BufferTy } std::vector ExecutionPlan::Impl::getChannelInfosByDstRank(int rank, BufferType bufferType) const { - auto pred = [rank, bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; }; + auto pred = [bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; }; return filter(this->channelInfosByDstRank.at(rank), pred); } @@ -126,6 +132,8 @@ std::vector ExecutionPlan::Impl::getUnpairedChannelInfos(int rank, return unpaired; } +std::vector ExecutionPlan::Impl::getNvlsInfos(int rank) const { return this->nvlsInfos.at(rank); } + std::vector ExecutionPlan::Impl::getConnectedPeers(int rank) const { std::set peers; for (const auto& info : this->channelInfos.at(rank)) { @@ -181,6 +189,8 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t 
outputSize, if (protocol == "LL") { this->isUsingPacket = true; } + this->inputSize = inputSize; + this->outputSize = outputSize; this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024); const auto& gpus = obj["gpus"]; @@ -192,9 +202,6 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, this->chunkGroups[rank] = gpu["chunkGroups"]; } this->setupChannels(gpus); - - this->inputSize = inputSize; - this->outputSize = outputSize; this->setupOperations(gpus, contsSrcOffset, constDstOffset); } @@ -224,15 +231,24 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output this->setupOperations(gpus, contsSrcOffset, constDstOffset); } -// Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors. -// Step 2. For each threadblock, construct a vector of channel indexes and keys. -void ExecutionPlan::Impl::setupChannels(const json& gpus) { - using mapKey = std::tuple; - std::map> chanConnectedPeersMap; - for (const auto& gpu : gpus) { - int rank = gpu["id"]; - std::vector channelInfos; - for (const auto& channel : gpu["channels"]) { +void ExecutionPlan::Impl::parseChannels( + const json& gpu, std::vector& channelInfos, std::vector& nvlsInfos, + std::map, std::vector>& chanConnectedPeersMap, int rank) { + for (const auto& channel : gpu["channels"]) { + ChannelType chanType = convertToChannelType(channel["type"]); + + if (chanType == ChannelType::NVLS) { + NvlsInfo info; + info.bufferType = convertToBufferType(channel["buff"]); + for (const auto& group : channel["rankGroups"]) { + info.bufferSize = (int)group["size"] * this->getUpperBoundChunkSize(rank, this->inputSize, this->outputSize); + info.ranks.clear(); + for (int rank : group["ranks"]) { + info.ranks.push_back(rank); + } + nvlsInfos.push_back(info); + } + } else { ChannelInfo info; info.srcBufferType = convertToBufferType(channel["srcbuff"]); info.dstBufferType = convertToBufferType(channel["dstbuff"]); @@ -244,7 +260,21 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) { } channelInfos.push_back(info); } + } +} + +// Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors. +// Step 2. For each threadblock, construct a vector of channel indexes and keys. 
+void ExecutionPlan::Impl::setupChannels(const json& gpus) { + using mapKey = std::tuple; + std::map> chanConnectedPeersMap; + for (const auto& gpu : gpus) { + int rank = gpu["id"]; + std::vector channelInfos; + std::vector nvlsInfos; + this->parseChannels(gpu, channelInfos, nvlsInfos, chanConnectedPeersMap, rank); this->channelInfos[rank] = channelInfos; + this->nvlsInfos[rank] = nvlsInfos; } for (const auto& [key, connectedFrom] : chanConnectedPeersMap) { @@ -260,21 +290,30 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) { // setup threadblockChannelMap for (const auto& gpu : gpus) { int rank = gpu["id"]; - auto channelTypes = {ChannelType::SM, ChannelType::PROXY}; + auto channelTypes = {ChannelType::SM, ChannelType::PROXY, ChannelType::NVLS}; std::unordered_map> channelMap; for (auto channelType : channelTypes) { const std::vector channelInfos = this->getChannelInfos(rank, channelType); int index = 0; - for (const auto& info : channelInfos) { - ChannelKey key = {info.srcBufferType, info.dstBufferType, info.channelType}; - for (size_t i = 0; i < info.connectedPeers.size(); i++) { + if (channelType == ChannelType::NVLS) { + const std::vector nvlsInfos = this->getNvlsInfos(rank); + for (const auto& info : nvlsInfos) { + ChannelKey key = {info.bufferType, info.bufferType, ChannelType::NVLS}; channelMap[key].push_back(index++); } + } else { + for (const auto& info : channelInfos) { + ChannelKey key = {info.srcBufferType, info.dstBufferType, info.channelType}; + for (size_t i = 0; i < info.connectedPeers.size(); i++) { + channelMap[key].push_back(index++); + } + } } } int nthreadblocks = gpu["threadblocks"].size(); this->threadblockSMChannelMap[rank].resize(nthreadblocks); this->threadblockProxyChannelMap[rank].resize(nthreadblocks); + this->threadblockNvlsChannelMap[rank].resize(nthreadblocks); for (const auto& threadblock : gpu["threadblocks"]) { for (const auto& channel : threadblock["channels"]) { ChannelType channelType = convertToChannelType(channel["ctype"]); @@ -284,6 +323,8 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) { this->threadblockSMChannelMap[rank][threadblock["id"]].emplace_back(channelMap[key][id], key); } else if (channelType == ChannelType::PROXY) { this->threadblockProxyChannelMap[rank][threadblock["id"]].emplace_back(channelMap[key][id], key); + } else if (channelType == ChannelType::NVLS) { + this->threadblockNvlsChannelMap[rank][threadblock["id"]].emplace_back(channelMap[key][id], key); } } } @@ -314,6 +355,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse int threadblockId = threadblock["id"]; const auto& smChannels = this->threadblockSMChannelMap[rank][threadblockId]; const auto& proxyChannels = this->threadblockProxyChannelMap[rank][threadblockId]; + const auto& nvlsChannels = this->threadblockNvlsChannelMap[rank][threadblockId]; for (size_t i = 0; i < smChannels.size(); i++) { const auto& [_, key] = smChannels[i]; channelIndexes[key].push_back(i); @@ -322,6 +364,10 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse const auto& [_, key] = proxyChannels[i]; channelIndexes[key].push_back(i); } + for (size_t i = 0; i < nvlsChannels.size(); i++) { + const auto& [_, key] = nvlsChannels[i]; + channelIndexes[key].push_back(i); + } for (const auto& op : threadblock["ops"]) { Operation operation = {}; std::vector chunkIndexes; @@ -330,17 +376,24 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse operation.channelType = 
convertToChannelType(op["ctype"]); } if (op.contains("i_cids")) { - operation.nInputs = op["i_cids"].size(); - for (int i = 0; i < operation.nInputs; i++) { - BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]); - BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]); - // Get the relevant channel index in rank channelInfos - operation.inputChannelIndexes[i] = - channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]]; - operation.inputOffsets[i] = - this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) + - getConstOffset(srcBufferType); - chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]); + if (operation.channelType == mscclpp::ChannelType::NVLS) { + BufferType srcBufferType = convertToBufferType(op["srcbuff"]); + operation.nvlsInputIndex = + channelIndexes[{srcBufferType, srcBufferType, ChannelType::NVLS}][op["i_cids"][0]["id"]]; + chunkIndexes.push_back((uint32_t)op["srcoff"]); + } else { + operation.nInputs = op["i_cids"].size(); + for (int i = 0; i < operation.nInputs; i++) { + BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]); + BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]); + // Get the relevant channel index in rank channelInfos + operation.inputChannelIndexes[i] = + channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]]; + operation.inputOffsets[i] = + this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) + + getConstOffset(srcBufferType); + chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]); + } } } // will have either srcs or i_cids @@ -357,14 +410,21 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse if (op.contains("o_cids")) { operation.nOutputs = op["o_cids"].size(); for (int i = 0; i < operation.nOutputs; i++) { - BufferType srcBufferType = convertToBufferType(op["o_buff"]["src"]); - BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]); - operation.outputChannelIndexes[i] = - channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]]; - operation.outputOffsets[i] = - this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) + - getConstOffset(dstBufferType); - chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]); + if (operation.channelType == mscclpp::ChannelType::NVLS) { + BufferType dstBufferType = convertToBufferType(op["dstbuff"]); + operation.nvlsInputIndex = + channelIndexes[{dstBufferType, dstBufferType, ChannelType::NVLS}][op["o_cids"][0]["id"]]; + chunkIndexes.push_back((uint32_t)op["dstoff"]); + } else { + BufferType srcBufferType = convertToBufferType(op["o_buff"]["src"]); + BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]); + operation.outputChannelIndexes[i] = + channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]]; + operation.outputOffsets[i] = + this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) + + getConstOffset(dstBufferType); + chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]); + } } } // will have either dsts or o_cids @@ -460,11 +520,19 @@ size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, size_t out return nChunkSize; } +size_t ExecutionPlan::Impl::getUpperBoundChunkSize(int rank, size_t inputSize, size_t outputSize) const { + auto sizePerRank = calcSizePerRank(rank, inputSize, 
outputSize); + uint32_t nChunks = sizePerRank.second; + return (sizePerRank.first + nChunks - 1) / nChunks; +} + void ExecutionPlan::Impl::reset() { this->operations.clear(); this->channelInfos.clear(); + this->nvlsInfos.clear(); this->threadblockSMChannelMap.clear(); this->threadblockProxyChannelMap.clear(); + this->threadblockNvlsChannelMap.clear(); this->inputChunks.clear(); this->outputChunks.clear(); this->scratchChunks.clear(); diff --git a/src/executor/executor.cc b/src/executor/executor.cc index 1fcb61865..ae34fa1bb 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -2,6 +2,7 @@ // Licensed under the MIT license. #include +#include #include #include #include @@ -23,6 +24,19 @@ struct ExecutionContextKey { } }; +void* getBuffer(BufferType type, void* sendbuff, void* recvbuff, void* scratch) { + switch (type) { + case BufferType::INPUT: + return sendbuff; + case BufferType::OUTPUT: + return recvbuff; + case BufferType::SCRATCH: + return scratch; + default: + throw Error("Invalid buffer type", ErrorCode::ExecutorError); + } +}; + struct DeviceExecutionPlanKey { size_t inputMessageSize; size_t outputMessageSize; @@ -97,11 +111,13 @@ namespace mscclpp { struct ExecutionContext { std::shared_ptr proxyService; std::unordered_map> connections; + std::vector> nvlsConnections; std::unordered_map, mscclpp::RegisteredMemory> registeredMemories; std::vector> smSemaphores; std::vector proxySemaphores; std::vector smChannels; std::vector proxyChannels; + std::vector nvlsChannels; std::unordered_map> deviceExecutionPlans; std::unordered_map> deviceExecutionPlansBuffers; std::shared_ptr scratchBuffer; @@ -152,7 +168,12 @@ struct Executor::Impl { ExecutionContext context; size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize); - std::shared_ptr scratchBuffer = allocExtSharedCuda(scratchBufferSize); + std::shared_ptr scratchBuffer; + if (isNvlsSupported()) { + scratchBuffer = allocSharedPhysicalCuda(scratchBufferSize); + } else { + scratchBuffer = allocExtSharedCuda(scratchBufferSize); + } context.scratchBuffer = scratchBuffer; context.scratchBufferSize = scratchBufferSize; context.proxyService = std::make_shared(); @@ -160,6 +181,7 @@ struct Executor::Impl { this->setupConnections(context, rank, plan); this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); + this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan); this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan); context.deviceExecutionPlansBuffers[devicePlanKey] = allocExtSharedCuda(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); @@ -202,6 +224,13 @@ struct Executor::Impl { for (size_t i = 0; i < connectionFutures.size(); i++) { context.connections[connectedPeers[i]] = connectionFutures[i].get(); } + + std::vector nvlsInfos = plan.impl_->getNvlsInfos(rank); + for (const NvlsInfo& info : nvlsInfos) { + std::shared_ptr nvlsConnection = + mscclpp::connectNvlsCollective(this->comm, info.ranks, info.bufferSize); + context.nvlsConnections.push_back(nvlsConnection); + } } void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, @@ -284,18 +313,6 @@ struct Executor::Impl { context.smSemaphores = std::move(smSemaphores); context.proxySemaphores = std::move(proxySemaphores); - auto getBuffer = [&](BufferType type) { - switch (type) { - case 
BufferType::INPUT: - return sendbuff; - case BufferType::OUTPUT: - return recvbuff; - case BufferType::SCRATCH: - return (void*)context.scratchBuffer.get(); - default: - throw Error("Invalid buffer type", ErrorCode::ExecutorError); - } - }; auto getBufferSize = [&](BufferType type) { switch (type) { case BufferType::INPUT: @@ -313,7 +330,7 @@ struct Executor::Impl { std::vector channelInfos = plan.impl_->getChannelInfos(rank, channelType); int index = 0; for (ChannelInfo& info : channelInfos) { - void* src = getBuffer(info.srcBufferType); + void* src = getBuffer(info.srcBufferType, sendbuff, recvbuff, context.scratchBuffer.get()); size_t bufferSize = getBufferSize(info.srcBufferType); TransportFlags transport = getTransportFlags(channelInfos, rank); RegisteredMemory localMemory = this->comm->registerMemory(src, bufferSize, transport); @@ -332,6 +349,19 @@ struct Executor::Impl { } } + void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, int rank, + const ExecutionPlan& plan) { + std::vector nvlsInfos = plan.impl_->getNvlsInfos(rank); + for (size_t i = 0; i < nvlsInfos.size(); i++) { + std::shared_ptr nvlsConnection = context.nvlsConnections[i]; + NvlsInfo info = nvlsInfos[i]; + void* buffer = getBuffer(info.bufferType, sendbuff, recvbuff, context.scratchBuffer.get()); + NvlsConnection::DeviceMulticastPointer deviceMulticastPointer = + nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, info.bufferSize); + context.nvlsChannels.push_back(deviceMulticastPointer); + } + } + void setupDeviceExecutionPlan(ExecutionContext& context, const DeviceExecutionPlanKey& key, int rank, const ExecutionPlan& plan) { std::vector deviceExecutionPlans; @@ -349,6 +379,10 @@ struct Executor::Impl { for (const auto& [index, _] : plan.impl_->threadblockProxyChannelMap.at(rank).at(threadblock)) { deviceExecutionPlan.channels.proxyChannels[chanIndex++] = mscclpp::deviceHandle(context.proxyChannels[index]); } + chanIndex = 0; + for (const auto& [index, _] : plan.impl_->threadblockNvlsChannelMap.at(rank).at(threadblock)) { + deviceExecutionPlan.channels.nvlsChannels[chanIndex++] = mscclpp::deviceHandle(context.nvlsChannels[index]); + } for (size_t i = 0; i < ops.size(); i++) { deviceExecutionPlan.operations[i] = ops[i]; } diff --git a/src/include/execution_common.hpp b/src/include/execution_common.hpp index 99bf36a4f..f4f4fbd8c 100644 --- a/src/include/execution_common.hpp +++ b/src/include/execution_common.hpp @@ -4,6 +4,7 @@ #ifndef MSCCLPP_EXECUTION_COMMON_HPP_ #define MSCCLPP_EXECUTION_COMMON_HPP_ +#include #include #include @@ -24,6 +25,7 @@ enum class ChannelType : uint8_t { NONE, SM, PROXY, + NVLS, }; // NOTE(chhwang): any modification here requires corresponding updates in `tools/npkit/npkit_trace_generator.py`. 
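The next hunk introduces OperationType::MULTI_LOAD_REDUCE_STORE. Conceptually, that operation reads through the multicast pointer with a multimem load-reduce and writes the reduced value back through the multicast pointer; a hedged device-side sketch for the float/uint4 case is shown below (the function name is hypothetical; the patch's actual implementation is handleMultiLoadReduceStore in execution_kernel.hpp):
```
// Illustrative sketch only (requires sm_90 and pointers from the multicast mapping).
// multimem.ld_reduce reads and reduces the value across every GPU bound to the
// multicast object; multimem.st broadcasts the result back to all of them.
#include <mscclpp/nvls_device.hpp>

#include <cstddef>

__device__ void multiLoadReduceStoreSketch(float* mcDst, float* mcSrc, size_t nElems) {
  uint4* src4 = reinterpret_cast<uint4*>(mcSrc);  // one uint4 covers 4 floats
  uint4* dst4 = reinterpret_cast<uint4*>(mcDst);
  for (size_t i = threadIdx.x; i < nElems / 4; i += blockDim.x) {
    uint4 val;
    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, reinterpret_cast<float*>(src4 + i));
    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, reinterpret_cast<float*>(dst4 + i));
  }
}
```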
@@ -46,11 +48,13 @@ enum class OperationType : uint8_t { REDUCE_SEND_PACKET, READ_REDUCE_COPY, READ_REDUCE_COPY_SEND, + MULTI_LOAD_REDUCE_STORE, }; struct Channels { mscclpp::DeviceHandle smChannels[MAX_CHANNEL]; mscclpp::DeviceHandle proxyChannels[MAX_CHANNEL]; + mscclpp::DeviceHandle nvlsChannels[MAX_CHANNEL]; }; struct Operation { @@ -61,12 +65,18 @@ struct Operation { uint8_t nInputs; uint8_t nOutputs; union { + // For ops which require reading from multiple remote sources uint8_t inputChannelIndexes[MAX_CHANNEL_PER_OPERATION]; + // For ops which require reading from multiple local sources BufferType inputBufferType; + uint8_t nvlsInputIndex; }; union { + // For ops which require writing to multiple remote destinations uint8_t outputChannelIndexes[MAX_CHANNEL_PER_OPERATION]; + // For ops which require writing to multiple local destinations BufferType outputBufferType; + uint8_t nvlsOutputIndex; }; uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION]; uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION]; @@ -75,12 +85,12 @@ struct Operation { uint32_t size; }; -// total size = 1920 + 6400 + 4 + 4(padding) + 12(align) = 8336 bytes +// total size = 2304 + 6400 + 4 + 12(padding) = 8720 bytes struct __attribute__((aligned(16))) DeviceExecutionPlan { uint8_t nSmChannels; // 1 bytes uint8_t nProxyChannels; // 1 bytes uint16_t nOperations; // 2 bytes - Channels channels; // 1920 bytes + Channels channels; // 2304 bytes Operation operations[MAX_OPERATION]; // 64 * 100 = 6400 bytes }; diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 0b64da197..1e9d6ac57 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -15,7 +15,8 @@ #include "execution_common.hpp" #if defined(MSCCLPP_DEVICE_COMPILE) -#include "gpu_data_types.hpp" +#include +#include namespace { template @@ -138,6 +139,34 @@ MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) { return add_vectors_helper<__bfloat162>(a, b); } +template +struct VectorType { + using type = T; + using nvls_type = T; + using nvls_type2 = T; +}; + +template <> +struct VectorType<__half> { + using type = __half2; + using nvls_type = uint4; + using nvls_type2 = uint1; +}; + +template <> +struct VectorType<__bfloat16> { + using type = __bfloat162; + using nvls_type = uint4; + using nvls_type2 = uint1; +}; + +template <> +struct VectorType { + using type = float; + using nvls_type = uint4; + using nvls_type2 = uint1; +}; + } // namespace #endif // defined(MSCCLPP_DEVICE_COMPILE) @@ -401,6 +430,37 @@ MSCCLPP_DEVICE_INLINE void handleCopy(void* dst, void* src, uint32_t dstOffset, Element::copy(dstData, srcData, size, threadIdx.x, blockDim.x); } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +template +MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* dst, T* src, uint32_t dstOffset, uint32_t srcOffset, + size_t size) { + using vectorType = typename VectorType::type; + using nvlsType = typename VectorType::nvls_type; + // nvls can only handle 4 bytes alignment + assert(size % sizeof(vectorType) == 0); + const size_t nInt4 = size / sizeof(nvlsType); + const size_t srcOffset4 = srcOffset / sizeof(nvlsType); + const size_t dstOffset4 = dstOffset / sizeof(nvlsType); + nvlsType* src4 = (nvlsType*)src; + nvlsType* dst4 = (nvlsType*)dst; + for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) { + nvlsType val; + DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (vectorType*)(src4 + srcOffset4 + idx)); + DeviceMulticastPointerDeviceHandle::multimemStore(val, 
(vectorType*)(dst4 + dstOffset4 + idx)); + } + // handle rest of data + size_t processed = nInt4 * sizeof(nvlsType); + using nvlsType2 = typename VectorType::nvls_type2; + const size_t startIdx = (srcOffset + processed) / sizeof(nvlsType2); + const size_t endIdx = (dstOffset + size) / sizeof(nvlsType2); + for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) { + nvlsType2 val; + DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (vectorType*)src + idx); + DeviceMulticastPointerDeviceHandle::multimemStore(val, (vectorType*)dst + idx); + } +} +#endif + template __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch, size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag @@ -433,6 +493,8 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu Operation* operations = localPlan->operations; DeviceHandle* smChannels = localPlan->channels.smChannels; DeviceHandle* proxyChannels = localPlan->channels.proxyChannels; + [[maybe_unused]] DeviceHandle* nvlsChannels = + localPlan->channels.nvlsChannels; #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) #if defined(MSCCLPP_DEVICE_HIP) @@ -530,6 +592,13 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, smChannels, op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size); } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + else if (op.type == OperationType::MULTI_LOAD_REDUCE_STORE) { + T* dst = (T*)(nvlsChannels[op.nvlsOutputIndex].mcPtr); + T* src = (T*)(nvlsChannels[op.nvlsInputIndex].mcPtr); + handleMultiLoadReduceStore(dst, src, op.dstOffset, op.srcOffset, op.size); + } +#endif #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT) NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT + (int)op.type, op.size, 0, NPKIT_GET_GPU_TIMESTAMP(), diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp index a44962782..07292d748 100644 --- a/src/include/execution_plan.hpp +++ b/src/include/execution_plan.hpp @@ -23,6 +23,12 @@ struct ChannelKey { channelType == other.channelType; } }; + +struct NvlsInfo { + std::vector ranks; + size_t bufferSize; + BufferType bufferType; +}; } // namespace mscclpp namespace std { @@ -63,6 +69,7 @@ struct ExecutionPlan::Impl { std::vector getChannelInfos(int rank, BufferType bufferType) const; std::vector getChannelInfosByDstRank(int rank, BufferType bufferType) const; std::vector getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType); + std::vector getNvlsInfos(int rank) const; std::vector getConnectedPeers(int rank) const; std::vector getConnectedBufferTypes(int rank) const; size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const; @@ -86,9 +93,12 @@ struct ExecutionPlan::Impl { std::unordered_map> channelInfos; std::unordered_map> channelInfosByDstRank; std::unordered_map, std::unordered_map> channelCountMap; + // for nvls channels + std::unordered_map> nvlsInfos; // threadblockChannelMap[rank][threadblock] = [channelIndex, channelKey] std::unordered_map>>> threadblockSMChannelMap; std::unordered_map>>> threadblockProxyChannelMap; + std::unordered_map>>> threadblockNvlsChannelMap; std::unordered_map inputChunks; std::unordered_map outputChunks; std::unordered_map scratchChunks; @@ -102,6 +112,13 @@ struct ExecutionPlan::Impl { size_t getOffset(int rank, size_t inputSize, size_t 
outputSize, uint32_t chunkIndex, uint32_t alignment = 16) const; size_t getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks, const std::vector offsets) const; + size_t getUpperBoundChunkSize(int rank, size_t inputSize, size_t outputSize) const; + + // helper functions to setup the channels + void parseChannels( + const nlohmann::json& gpu, std::vector& channelInfos, std::vector& nvlsInfos, + std::map, std::vector>& chanConnectedPeersMap, + int rank); }; } // namespace mscclpp diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 11cd30231..2f7727636 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -27,6 +27,10 @@ struct TransportInfo { const IbMr* ibMr; IbMrInfo ibMrInfo; }; + struct { + char shareableHandle[64]; + size_t offsetFromBase; + }; }; }; @@ -39,6 +43,7 @@ struct RegisteredMemory::Impl { size_t size; uint64_t hostHash; uint64_t pidHash; + bool isCuMemMapAlloc; TransportFlags transports; std::vector transportInfos; diff --git a/src/nvls.cc b/src/nvls.cc index 5504e3b25..3221e6e00 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -15,7 +15,7 @@ namespace mscclpp { -#if (USE_NVLS) +#if (CUDA_NVLS_SUPPORTED) class NvlsConnection::Impl : public std::enable_shared_from_this { public: // use this only for the root of the NVLS @@ -31,10 +31,11 @@ class NvlsConnection::Impl : public std::enable_shared_from_this bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize); + std::shared_ptr bindMemory(CUdeviceptr devicePtr, size_t devBuffSize); private: friend class NvlsConnection; + CUmemGenericAllocationHandle mcHandle_; CUmulticastObjectProp mcProp_; size_t bufferSize_; @@ -70,8 +71,10 @@ NvlsConnection::Impl::Impl(size_t bufferSize, int numDevices) { throw mscclpp::SysError("getpid() failed", errno); } - INFO(MSCCLPP_COLL, "NVLS handle created on root with size %ld. minGranularity %ld and recommendedGranularity %ld\n", - mcProp_.size, minMcGran_, mcGran_); + INFO(MSCCLPP_COLL, + "NVLS handle created on root with size %ld. minGranularity %ld and recommendedGranularity %ld buffer size is " + "%ld, adjusted size is %ld", + mcProp_.size, minMcGran_, mcGran_, bufferSize, bufferSize_); } NvlsConnection::Impl::Impl(const std::vector& data) { @@ -128,6 +131,8 @@ void NvlsConnection::Impl::addDevice(int cudaDeviceId) { INFO(MSCCLPP_COLL, "NVLS connection created"); } +// TODO(binyli): For cuMemMap, we can not map handle to va with offset not equal to 0. +// Then we don't need to maintain the freeRanges_ list. For different memory, we could map to different mc handle. 
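Before the implementation details below, a hedged host-side sketch of how the new bind API is meant to be used; bindAllocatedMemory and deviceHandle come from this patch, while the wrapper function itself is hypothetical:
```
// Illustrative sketch only: bind a buffer from allocSharedPhysicalCuda to the
// multicast object. Keep both physBuf and the returned object alive while the
// handle from deviceHandle() (devicePtr, mcPtr, bufferSize) is used by kernels.
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/nvls.hpp>

#include <cstddef>
#include <memory>

mscclpp::NvlsConnection::DeviceMulticastPointer bindForNvls(
    std::shared_ptr<mscclpp::NvlsConnection> conn, std::shared_ptr<char> physBuf, size_t bytes) {
  return conn->bindAllocatedMemory((CUdeviceptr)physBuf.get(), bytes);
}
```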
size_t NvlsConnection::Impl::allocateBuffer(size_t size) { if (freeRanges_.empty()) { throw Error("This NVLS connection mapped more than it was supposed to", ErrorCode::InvalidUsage); @@ -187,24 +192,21 @@ void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept { } } -std::shared_ptr NvlsConnection::Impl::bindMemory(CUmemGenericAllocationHandle memHandle, size_t devBuffSize) { +std::shared_ptr NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, size_t devBuffSize) { + devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_; size_t offset = allocateBuffer(devBuffSize); - MSCCLPP_CUTHROW(cuMulticastBindMem(mcHandle_, offset /*mcOffset*/, memHandle, 0 /*memOffset*/, devBuffSize, 0)); + MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0)); char* mcPtr; - - CUmemAccessDesc accessDesc = {}; - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - int deviceId = -1; - MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), devBuffSize, minMcGran_, 0U, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), devBuffSize, 0, mcHandle_, 0)); - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)(mcPtr), devBuffSize, &accessDesc, 1)); + setReadWriteMemoryAccess(mcPtr, devBuffSize); + INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, devBuffSize); auto deleter = [=, self = shared_from_this()](char* ptr) { + int deviceId; CUdevice device; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); MSCCLPP_CUTHROW(cuDeviceGet(&device, deviceId)); MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, devBuffSize)); MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, devBuffSize)); @@ -214,7 +216,8 @@ std::shared_ptr NvlsConnection::Impl::bindMemory(CUmemGenericAllocationHan return std::shared_ptr(mcPtr, deleter); } -#else // !(USE_NVLS) + +#else // !(CUDA_NVLS_SUPPORTED) class NvlsConnection::Impl { public: // use this only for the root of the NVLS @@ -227,15 +230,15 @@ class NvlsConnection::Impl { std::vector serialize() { throw notSupportedError; } size_t allocateBuffer(size_t) { throw notSupportedError; } void freeBuffer(size_t, size_t) { throw notSupportedError; } - std::shared_ptr bindMemory(CUmemGenericAllocationHandle, size_t) { throw notSupportedError; } + std::shared_ptr bindMemory(CUdeviceptr, size_t) { throw notSupportedError; } void addDevice(int) { throw notSupportedError; } size_t getMinMcGran() { throw notSupportedError; } private: Error notSupportedError = - Error("NVLS is not supported on this CUDA version (< 12.1) or kernel version (< 5.6.0)", ErrorCode::InvalidUsage); + Error("NVLS is not supported on this CUDA version (< 12.3) or kernel version (< 5.6.0)", ErrorCode::InvalidUsage); }; -#endif // !(USE_NVLS) +#endif // !(CUDA_NVLS_SUPPORTED) const int NvlsConnection::DefaultNvlsBufferSize = (1 << 29); @@ -254,25 +257,20 @@ NvlsConnection::NvlsConnection(const std::vector& data) : pimpl_(std::make std::vector NvlsConnection::serialize() { return pimpl_->serialize(); } -std::shared_ptr NvlsConnection::allocateAndBindCuda(size_t size) { - auto mem = allocSharedPhysicalCuda(size, pimpl_->getMinMcGran()); - auto mcPtr = pimpl_->bindMemory(mem->memHandle_, size); - return std::make_shared(mem, mcPtr, size); -} - -std::shared_ptr NvlsConnection::bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size) { - return pimpl_->bindMemory(memHandle, size); 
+NvlsConnection::DeviceMulticastPointer NvlsConnection::bindAllocatedMemory(CUdeviceptr devicePtr, size_t size) { + auto mcPtr = pimpl_->bindMemory(devicePtr, size); + return DeviceMulticastPointer((void*)devicePtr, mcPtr, size); } NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() { NvlsConnection::DeviceMulticastPointer::DeviceHandle device; - device.devicePtr = this->deviceMem_->devicePtr_; + device.devicePtr = this->devicePtr_; device.mcPtr = this->mcPtr_.get(); device.bufferSize = this->bufferSize_; return device; }; -char* NvlsConnection::DeviceMulticastPointer::getDevicePtr() { return deviceMem_->devicePtr_; }; +void* NvlsConnection::DeviceMulticastPointer::getDevicePtr() { return devicePtr_; }; size_t NvlsConnection::getMultiCastMinGranularity() { return pimpl_->getMinMcGran(); } diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 0702c497b..1ad97c1b2 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -11,6 +11,65 @@ #include "debug.h" #include "utils_internal.hpp" +#define MSCCLPP_CULOG_WARN(cmd) \ + do { \ + CUresult err = cmd; \ + if (err != CUDA_SUCCESS) { \ + const char* errStr; \ + if (cuGetErrorString(err, &errStr) != CUDA_SUCCESS) { \ + errStr = "failed to get error string"; \ + } \ + WARN("Call to " #cmd " failed, error is %s", errStr); \ + } \ + } while (false) + +namespace { +// Get the recommended granularity for cuMemAddressReserve +size_t getRecommendedGranularity() { +#if (CUDA_NVLS_SUPPORTED) + size_t gran = 0; + int deviceId = -1; + int currentDevice = -1; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); + MSCCLPP_CUTHROW(cuDeviceGet(¤tDevice, deviceId)); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = + (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); + prop.location.id = currentDevice; + MSCCLPP_CUTHROW(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + return gran; +#else + throw mscclpp::Error("Only support GPU with NVLS support", mscclpp::ErrorCode::InvalidUsage); +#endif +} + +CUmemAllocationHandleType getNvlsCompatibleMemHandleType() { +#if (CUDA_NVLS_SUPPORTED) + return CU_MEM_HANDLE_TYPE_FABRIC; +#else + throw mscclpp::Error("Only support GPU with NVLS support", mscclpp::ErrorCode::InvalidUsage); +#endif +} + +// Check if ptr is allocaed by cuMemMap +bool isCuMemMapAllocated(void* ptr) { + CUmemGenericAllocationHandle handle; + CUresult result = cuMemRetainAllocationHandle(&handle, ptr); + if (result != CUDA_SUCCESS) { + return false; + } + MSCCLPP_CUTHROW(cuMemRelease(handle)); + if (!mscclpp::isNvlsSupported()) { + throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage); + } + return true; +} + +} // namespace + namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports, Context::Impl& contextImpl) @@ -23,15 +82,24 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports, if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; transportInfo.transport = Transport::CudaIpc; - cudaIpcMemHandle_t handle; void* baseDataPtr; size_t baseDataSize; // dummy MSCCLPP_CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); - MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); - // TODO: bug with offset of 
base? - transportInfo.cudaIpcBaseHandle = handle; - transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr; + this->isCuMemMapAlloc = isCuMemMapAllocated(baseDataPtr); + if (this->isCuMemMapAlloc) { + CUmemGenericAllocationHandle handle; + MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, baseDataPtr)); + MSCCLPP_CUTHROW( + cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, getNvlsCompatibleMemHandleType(), 0)); + transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr; + } else { + cudaIpcMemHandle_t handle; + MSCCLPP_CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); + // TODO: bug with offset of base? + transportInfo.cudaIpcBaseHandle = handle; + transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr; + } this->transportInfos.push_back(transportInfo); } if ((transports & AllIBTransports).any()) { @@ -75,6 +143,8 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { std::copy_n(reinterpret_cast(&pimpl_->size), sizeof(pimpl_->size), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl_->hostHash), sizeof(pimpl_->hostHash), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl_->pidHash), sizeof(pimpl_->pidHash), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl_->isCuMemMapAlloc), sizeof(pimpl_->isCuMemMapAlloc), + std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl_->transports), sizeof(pimpl_->transports), std::back_inserter(result)); if (pimpl_->transportInfos.size() > static_cast(std::numeric_limits::max())) { throw mscclpp::Error("Too many transport info entries", ErrorCode::InternalError); @@ -84,10 +154,17 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { for (auto& entry : pimpl_->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); if (entry.transport == Transport::CudaIpc) { - std::copy_n(reinterpret_cast(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle), - std::back_inserter(result)); - std::copy_n(reinterpret_cast(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase), - std::back_inserter(result)); + if (pimpl_->isCuMemMapAlloc) { + std::copy_n(reinterpret_cast(&entry.shareableHandle), sizeof(entry.shareableHandle), + std::back_inserter(result)); + std::copy_n(reinterpret_cast(&entry.offsetFromBase), sizeof(entry.offsetFromBase), + std::back_inserter(result)); + } else { + std::copy_n(reinterpret_cast(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle), + std::back_inserter(result)); + std::copy_n(reinterpret_cast(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase), + std::back_inserter(result)); + } } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); } else { @@ -111,6 +188,8 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { it += sizeof(this->hostHash); std::copy_n(it, sizeof(this->pidHash), reinterpret_cast(&this->pidHash)); it += sizeof(this->pidHash); + std::copy_n(it, sizeof(this->isCuMemMapAlloc), reinterpret_cast(&this->isCuMemMapAlloc)); + it += sizeof(this->isCuMemMapAlloc); std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); it += sizeof(this->transports); int8_t transportCount; @@ -121,12 +200,19 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += 
sizeof(transportInfo.transport); if (transportInfo.transport == Transport::CudaIpc) { - std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), - reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); - it += sizeof(transportInfo.cudaIpcBaseHandle); - std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), - reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); - it += sizeof(transportInfo.cudaIpcOffsetFromBase); + if (this->isCuMemMapAlloc) { + std::copy_n(it, sizeof(transportInfo.shareableHandle), reinterpret_cast(&transportInfo.shareableHandle)); + it += sizeof(transportInfo.shareableHandle); + std::copy_n(it, sizeof(transportInfo.offsetFromBase), reinterpret_cast(&transportInfo.offsetFromBase)); + it += sizeof(transportInfo.offsetFromBase); + } else { + std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), + reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); + it += sizeof(transportInfo.cudaIpcBaseHandle); + std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), + reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); + it += sizeof(transportInfo.cudaIpcOffsetFromBase); + } } else if (AllIBTransports.has(transportInfo.transport)) { std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast(&transportInfo.ibMrInfo)); it += sizeof(transportInfo.ibMrInfo); @@ -148,8 +234,18 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { // The memory is local to the machine but not to the process, so we need to open the CUDA IPC handle auto entry = getTransportInfo(Transport::CudaIpc); void* base; - MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); - this->data = static_cast(base) + entry.cudaIpcOffsetFromBase; + if (this->isCuMemMapAlloc) { + CUmemGenericAllocationHandle handle; + MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsCompatibleMemHandleType())); + size_t gran = getRecommendedGranularity(); + MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0)); + MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0)); + setReadWriteMemoryAccess(base, this->size); + this->data = static_cast(base) + entry.offsetFromBase; + } else { + MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); + this->data = static_cast(base) + entry.cudaIpcOffsetFromBase; + } INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", this->data); } else { // No valid data pointer can be set @@ -161,11 +257,22 @@ RegisteredMemory::Impl::~Impl() { // Close the CUDA IPC handle if it was opened during deserialization if (data && transports.has(Transport::CudaIpc) && getHostHash() == this->hostHash && getPidHash() != this->pidHash) { void* base = static_cast(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase; - cudaError_t err = cudaIpcCloseMemHandle(base); - if (err != cudaSuccess) { - WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err)); + if (this->isCuMemMapAlloc) { + CUmemGenericAllocationHandle handle; + size_t size = 0; + MSCCLPP_CULOG_WARN(cuMemRetainAllocationHandle(&handle, base)); + MSCCLPP_CULOG_WARN(cuMemRelease(handle)); + MSCCLPP_CULOG_WARN(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)base)); + MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size)); + MSCCLPP_CULOG_WARN(cuMemRelease(handle)); + MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size)); } else { - INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base); + 
cudaError_t err = cudaIpcCloseMemHandle(base); + if (err != cudaSuccess) { + WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err)); + } else { + INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base); + } } data = nullptr; } diff --git a/src/utils.cc b/src/utils.cc index 8475f2f60..fb470a4ab 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -68,12 +68,19 @@ std::string getHostName(int maxlen, const char delim) { } bool isNvlsSupported() { -#if (CUDART_VERSION >= 12010) - CUdevice dev; - int isNvlsSupported; - MSCCLPP_CUTHROW(cuCtxGetDevice(&dev)); - MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isNvlsSupported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev)); - return isNvlsSupported == 1; + [[maybe_unused]] static bool result = false; + [[maybe_unused]] static bool isChecked = false; +#if (CUDA_NVLS_SUPPORTED) + if (!isChecked) { + int isMulticastSupported; + int isFabricSupported; + CUdevice dev; + MSCCLPP_CUTHROW(cuCtxGetDevice(&dev)); + MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isMulticastSupported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev)); + MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isFabricSupported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev)); + result = (isMulticastSupported == 1 && isFabricSupported == 1); + } + return result; #endif return false; } diff --git a/test/execution-files/allreduce_nvls.json b/test/execution-files/allreduce_nvls.json new file mode 100644 index 000000000..069b5df9d --- /dev/null +++ b/test/execution-files/allreduce_nvls.json @@ -0,0 +1,1458 @@ +{ + "name": "allreduce_nvls", + "colletive": "allreduce", + "protocol": "Simple", + "inplace": true, + "gpus": [ + { + "id": 0, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 0 + }, + { + "id": 1, + "off": 0 + }, + { + "id": 2, + "off": 0 + }, + { + "id": 3, + "off": 0 + }, + { + "id": 4, + "off": 0 + }, + { + "id": 5, + "off": 0 + }, + { + "id": 6, + "off": 0 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 0 + }, + { + "id": 1, + "off": 0 + }, + { + "id": 2, + "off": 0 + }, + { + "id": 3, + "off": 0 + }, + { + "id": 4, + "off": 0 + }, + { + "id": 5, + "off": 0 + }, + { + "id": 6, + "off": 0 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 0, + "srcbuff": "i", + "srcoff": 0, + "dst": 0, + "dstbuff": "i", + "dstoff": 0, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + } + ] + } + ], + "channels": [ + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + }, + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + }, + { + "id": 1, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 1 + }, + { + "id": 1, + "off": 1 
+ }, + { + "id": 2, + "off": 1 + }, + { + "id": 3, + "off": 1 + }, + { + "id": 4, + "off": 1 + }, + { + "id": 5, + "off": 1 + }, + { + "id": 6, + "off": 1 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 1 + }, + { + "id": 1, + "off": 1 + }, + { + "id": 2, + "off": 1 + }, + { + "id": 3, + "off": 1 + }, + { + "id": 4, + "off": 1 + }, + { + "id": 5, + "off": 1 + }, + { + "id": 6, + "off": 1 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 1, + "srcbuff": "i", + "srcoff": 1, + "dst": 1, + "dstbuff": "i", + "dstoff": 1, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + } + ] + } + ], + "channels": [ + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + }, + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + }, + { + "id": 2, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 2 + }, + { + "id": 1, + "off": 2 + }, + { + "id": 2, + "off": 2 + }, + { + "id": 3, + "off": 2 + }, + { + "id": 4, + "off": 2 + }, + { + "id": 5, + "off": 2 + }, + { + "id": 6, + "off": 2 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 2 + }, + { + "id": 1, + "off": 2 + }, + { + "id": 2, + "off": 2 + }, + { + "id": 3, + "off": 2 + }, + { + "id": 4, + "off": 2 + }, + { + "id": 5, + "off": 2 + }, + { + "id": 6, + "off": 2 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 2, + "srcbuff": "i", + "srcoff": 2, + "dst": 2, + "dstbuff": "i", + "dstoff": 2, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 1, + 3, + 4, + 5, + 6, + 7 + ] + }, + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + } + ] + }, + { + "id": 3, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 3 + }, + { + "id": 1, + "off": 3 + }, + { + "id": 2, + "off": 3 + }, + { + "id": 3, + "off": 3 + }, + { + "id": 4, + "off": 3 + }, + { + "id": 5, + "off": 3 + }, + { + "id": 6, + "off": 3 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 3 + }, + { + "id": 1, + "off": 3 + 
}, + { + "id": 2, + "off": 3 + }, + { + "id": 3, + "off": 3 + }, + { + "id": 4, + "off": 3 + }, + { + "id": 5, + "off": 3 + }, + { + "id": 6, + "off": 3 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 3, + "srcbuff": "i", + "srcoff": 3, + "dst": 3, + "dstbuff": "i", + "dstoff": 3, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 1, + 2, + 4, + 5, + 6, + 7 + ] + }, + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + } + ] + }, + { + "id": 4, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 4 + }, + { + "id": 1, + "off": 4 + }, + { + "id": 2, + "off": 4 + }, + { + "id": 3, + "off": 4 + }, + { + "id": 4, + "off": 4 + }, + { + "id": 5, + "off": 4 + }, + { + "id": 6, + "off": 4 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 4 + }, + { + "id": 1, + "off": 4 + }, + { + "id": 2, + "off": 4 + }, + { + "id": 3, + "off": 4 + }, + { + "id": 4, + "off": 4 + }, + { + "id": 5, + "off": 4 + }, + { + "id": 6, + "off": 4 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 4, + "srcbuff": "i", + "srcoff": 4, + "dst": 4, + "dstbuff": "i", + "dstoff": 4, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 1, + 2, + 3, + 5, + 6, + 7 + ] + }, + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + } + ] + }, + { + "id": 5, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 5 + }, + { + "id": 1, + "off": 5 + }, + { + "id": 2, + "off": 5 + }, + { + "id": 3, + "off": 5 + }, + { + "id": 4, + "off": 5 + }, + { + "id": 5, + "off": 5 + }, + { + "id": 6, + "off": 5 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 5 + }, + { + "id": 1, + "off": 5 + }, + { + "id": 2, + "off": 5 + }, + { + "id": 3, + "off": 5 + }, + { + "id": 4, + "off": 5 + }, + { + "id": 5, + "off": 5 + }, + { + "id": 6, + "off": 5 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + 
"o_cids": [ + { + "id": 0 + } + ], + "src": 5, + "srcbuff": "i", + "srcoff": 5, + "dst": 5, + "dstbuff": "i", + "dstoff": 5, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 1, + 2, + 3, + 4, + 6, + 7 + ] + }, + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + } + ] + }, + { + "id": 6, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 6 + }, + { + "id": 1, + "off": 6 + }, + { + "id": 2, + "off": 6 + }, + { + "id": 3, + "off": 6 + }, + { + "id": 4, + "off": 6 + }, + { + "id": 5, + "off": 6 + }, + { + "id": 6, + "off": 6 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 6 + }, + { + "id": 1, + "off": 6 + }, + { + "id": 2, + "off": 6 + }, + { + "id": 3, + "off": 6 + }, + { + "id": 4, + "off": 6 + }, + { + "id": 5, + "off": 6 + }, + { + "id": 6, + "off": 6 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 6, + "srcbuff": "i", + "srcoff": 6, + "dst": 6, + "dstbuff": "i", + "dstoff": 6, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "nvls", + "cids": [ + 0 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 1, + 2, + 3, + 4, + 5, + 7 + ] + }, + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + } + ] + }, + { + "id": 7, + "inputChunks": 8, + "outputChunks": 0, + "scratchChunks": 0, + "chunkGroups": 8, + "threadblocks": [ + { + "id": 0, + "ops": [ + { + "name": "signal", + "o_buff": { + "src": "i", + "dst": "i" + }, + "o_cids": [ + { + "id": 0, + "off": 7 + }, + { + "id": 1, + "off": 7 + }, + { + "id": 2, + "off": 7 + }, + { + "id": 3, + "off": 7 + }, + { + "id": 4, + "off": 7 + }, + { + "id": 5, + "off": 7 + }, + { + "id": 6, + "off": 7 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "wait", + "i_buff": { + "src": "i", + "dst": "i" + }, + "i_cids": [ + { + "id": 0, + "off": 7 + }, + { + "id": 1, + "off": 7 + }, + { + "id": 2, + "off": 7 + }, + { + "id": 3, + "off": 7 + }, + { + "id": 4, + "off": 7 + }, + { + "id": 5, + "off": 7 + }, + { + "id": 6, + "off": 7 + } + ], + "ctype": "sm", + "cnt": 1 + }, + { + "name": "nop", + "deps": [ + { + "tb": 0, + "step": 1 + } + ] + }, + { + "name": "glres", + "i_cids": [ + { + "id": 0 + } + ], + "o_cids": [ + { + "id": 0 + } + ], + "src": 7, + "srcbuff": "i", + "srcoff": 7, + "dst": 7, + "dstbuff": "i", + "dstoff": 7, + "ctype": "nvls", + "cnt": 1 + } + ], + "channels": [ + { + "src": "i", + "dst": "i", + "ctype": "sm", + "cids": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "src": "i", + "dst": "i", + "ctype": "nvls", + 
"cids": [ + 0 + ] + } + ] + } + ], + "channels": [ + { + "srcbuff": "i", + "dstbuff": "i", + "type": "sm", + "connectedTo": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ] + }, + { + "buff": "i", + "type": "nvls", + "rankGroups": [ + { + "size": 8, + "ranks": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ] + } + ] + } + ] + } + ], + "num_threads_per_block": 1024, + "use_double_scratch_buffer": false +} diff --git a/test/executor_test.cc b/test/executor_test.cc index 2f6d9cf5d..3fc0b1e21 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -131,7 +131,11 @@ int main(int argc, char* argv[]) { } mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath); +#if (CUDA_NVLS_SUPPORTED) + std::shared_ptr sendbuff = mscclpp::allocSharedPhysicalCuda(bufferSize); +#else std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); +#endif std::vector dataHost(bufferSize / sizeof(int), rank); MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice)); double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, niters, ngraphIters, packetType); diff --git a/test/nvls_test.cu b/test/nvls_test.cu index 55ece3fcf..7a1a54ade 100644 --- a/test/nvls_test.cu +++ b/test/nvls_test.cu @@ -1,18 +1,17 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include - -#if (USE_NVLS) -#include -#include -#include #include +#include #include #include #include #include +#if (CUDA_NVLS_SUPPORTED) +#include +#include +#include #define CUCHECK(cmd) \ do { \ @@ -41,31 +40,31 @@ #define MULTIMEM_LD(val, ptr) #endif -__global__ void init_kernel(float* uc_ptr, int size, int myrank, int nranks) { - for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < size; idx += blockDim.x * gridDim.x) { +__global__ void init_kernel(float* uc_ptr, size_t size, int myrank, int nranks) { + for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < size; idx += blockDim.x * gridDim.x) { uc_ptr[idx] = myrank + idx; } } -__global__ void check_correctness(float* uc_ptr, int size, int myrank, int nranks) { - for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < size; idx += blockDim.x * gridDim.x) { +__global__ void check_correctness(float* uc_ptr, size_t size, int myrank, int nranks) { + for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < size; idx += blockDim.x * gridDim.x) { float expected = (float)((nranks * (nranks - 1)) / 2 + nranks * idx); if (abs(uc_ptr[idx] - expected) > 0.01 * expected) { - printf("error! idx %d: %f != %f\n", idx, uc_ptr[idx], expected); + printf("error! idx %ld: %f != %f\n", idx, uc_ptr[idx], expected); } } } -__global__ void testing(float* mc_ptr, int size, int myrank, int nranks) { +__global__ void testing(float* mc_ptr, size_t size, int myrank, int nranks) { // for allreduce we dont even need an UC pointer. 
just using same mc_ptr for in-place reduction // line is assumed to be 16B 4 ints of 8 halves - int my_st = ((int64_t)size * (int64_t)myrank) / (int64_t)nranks; - int my_en = ((int64_t)size * (int64_t)(myrank + 1)) / (int64_t)nranks; + size_t my_st = ((int64_t)size * (int64_t)myrank) / (int64_t)nranks; + size_t my_en = ((int64_t)size * (int64_t)(myrank + 1)) / (int64_t)nranks; - int my_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4; - int my_step = blockDim.x * gridDim.x * 4; + size_t my_offset = (threadIdx.x + blockIdx.x * blockDim.x) * 4; + size_t my_step = blockDim.x * gridDim.x * 4; - for (int idx = my_st + my_offset; idx < my_en; idx += my_step) { + for (size_t idx = my_st + my_offset; idx < my_en; idx += my_step) { [[maybe_unused]] uint4 val; MULTIMEM_LD(val, mc_ptr + idx); MULTIMEM_ST(val, mc_ptr + idx); @@ -80,7 +79,7 @@ int main() { cudaSetDevice(myrank); - size_t size = 1024 * 1024 * 512; + size_t size = 1024ULL * 1024ULL * 512ULL * 16; CUmemAllocationHandleType handleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; CUmulticastObjectProp mcProp = {}; @@ -138,7 +137,7 @@ int main() { prop.requestedHandleTypes = handleType; // allocate physical memory (data buffer) - CUCHECK(cuMemCreate(&memhandle, size, &prop, 0 /*flags*/)); + CUCHECK(cuMemCreate(&memhandle, mcSize, &prop, 0 /*flags*/)); void* uc_va; void* mc_va; @@ -148,14 +147,14 @@ int main() { accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; // Map a VA to UC space - CUCHECK(cuMemAddressReserve((CUdeviceptr*)&uc_va, size, minGran, 0U, 0)); - cudaMemset(uc_va, 0, size); - CUCHECK(cuMemMap((CUdeviceptr)uc_va, size, 0, memhandle, 0)); + CUCHECK(cuMemAddressReserve((CUdeviceptr*)&uc_va, mcSize, minGran, 0U, 0)); + cudaMemset(uc_va, 0, mcSize); + CUCHECK(cuMemMap((CUdeviceptr)uc_va, mcSize, 0, memhandle, 0)); // set access on UC address - CUCHECK(cuMemSetAccess((CUdeviceptr)uc_va, size, &accessDesc, 1)); + CUCHECK(cuMemSetAccess((CUdeviceptr)uc_va, mcSize, &accessDesc, 1)); // everyone binds memory to the multicast - CUCHECK(cuMulticastBindMem(handle, 0 /*mcOffset*/, memhandle, 0 /*memOffset*/, size, 0)); + CUCHECK(cuMulticastBindAddr(handle, 0 /*mcOffset*/, (CUdeviceptr)uc_va, mcSize, 0)); MPI_Barrier(MPI_COMM_WORLD); // usual VA business: map both MC and PA to two different VA addresses @@ -203,11 +202,11 @@ int main() { return 0; } -#else // !(USE_NVLS) +#else // !(CUDA_NVLS_SUPPORTED) int main() { printf("This test requires NVLS to be enabled\n"); return 0; } -#endif // !(USE_NVLS) +#endif // !(CUDA_NVLS_SUPPORTED)
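
For context on how the pieces above fit together at a call site, here is a minimal, hypothetical sketch in the spirit of the `test/executor_test.cc` change: it picks the physically backed allocator when both the build and the GPU support NVLS, and falls back to the existing allocation path otherwise. It is not part of the patch; the helper name `allocateExecutorBuffer` and the choice of `char` as the element type are assumptions for illustration only.

```cpp
// Sketch only (not in this patch): choose the allocator for a buffer that will be
// registered with the executor. Buffers driven through the "nvls" channel type need
// to be cuMemMap-backed so that RegisteredMemory can export a shareable handle;
// other buffers keep the existing cudaIpc path.
#include <mscclpp/gpu_utils.hpp>  // allocSharedPhysicalCuda / allocExtSharedCuda
#include <mscclpp/utils.hpp>      // isNvlsSupported

#include <cstddef>
#include <memory>

std::shared_ptr<char> allocateExecutorBuffer(size_t bufferSize) {
#if (CUDA_NVLS_SUPPORTED)
  if (mscclpp::isNvlsSupported()) {
    // Physically backed allocation (cuMemCreate + cuMemMap).
    return mscclpp::allocSharedPhysicalCuda<char>(bufferSize);
  }
#endif
  // Fallback: ordinary allocation, exchanged via cudaIpcGetMemHandle as before.
  return mscclpp::allocExtSharedCuda<char>(bufferSize);
}
```

When such a buffer is registered, the `RegisteredMemory` changes above take the cuMemMap branch: `serialize()` records `isCuMemMapAlloc` together with the exported shareable handle and `offsetFromBase`, and the peer re-maps the memory with `cuMemImportFromShareableHandle`/`cuMemMap` plus `setReadWriteMemoryAccess` instead of `cudaIpcOpenMemHandle`.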