Skip to content

Commit

Permalink
[Distributed] all_reduce op and distributed info in graphs (#284)
Browse files Browse the repository at this point in the history
* add the `all_reduce` op
* add nccl-related headers and libs when building tasks (as a new pass
include_nccl_pass)
* We now have an example of `all_reduce(relu(x * w))` in
`./examples/distributed/test.py`

---------

Co-authored-by: Hanjie <[email protected]>
  • Loading branch information
soodoshll and hjjq authored Jun 29, 2023
1 parent e0c8046 commit 7c52c9d
Show file tree
Hide file tree
Showing 18 changed files with 886 additions and 122 deletions.
104 changes: 38 additions & 66 deletions examples/distributed/test.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,65 @@
"""
Testing script for distributed components for hidet
To debug, set the environment variable NCCL_DEBUG=INFO
To install nccl, run
pip install nvidia-nccl-cu11==2.18.3
Or
pip install nvidia-nccl-cu12==2.18.3
"""
import hidet
import multiprocessing
from multiprocessing import Process
import numpy
import argparse
import atexit
import os

import hidet
import hidet.cuda.nccl
from hidet.cuda import nccl
from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename
from hidet.ffi import runtime_api
from hidet.lang import attrs
from hidet.ir.primitives.cuda.nccl import all_reduce
from hidet.ir.type import data_type
from hidet.utils import prod
from hidet.drivers import build_ir_module
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
from hidet.runtime import load_compiled_module

print("NCCL version:", nccl.nccl_version())

parser = argparse.ArgumentParser()
parser.add_argument("n_gpus", type=int)
parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg'])
args = parser.parse_args()

def run(world_size, rank, shared_id, barrier):
def run(world_size, rank):
numpy.random.seed(rank)

# Initialize unique id
if rank == 0:
nccl.init_unique_id(shared_id)

barrier.wait()
hidet.cuda.set_device(rank)
hidet.distributed.init_process_group(init_method='file://tmp', world_size=world_size, rank=rank)
hidet.distributed.set_nccl_comms()

print('initialize', rank)
# Create NcclCommunicator and set the cuda context
# this part should be moved into CompiledGraph in the future
comm = nccl.create_comm(world_size, shared_id, rank)
comms_array = nccl.comms_to_array([comm])
runtime_api.set_nccl_comms(comms_array)

# Initialize send and receive buffer
device = f"cuda:{rank}"
send = hidet.randn([2, 2], device=device)
recv = hidet.empty([2, 2], device=device)

print(rank, send)

dtype = data_type('float32')
shape = [2, 2]
nbytes = dtype.nbytes * prod(shape)

# Define IRModule
with hidet.script_module() as script_module:
@hidet.script
def launch(send: dtype[shape], recv: dtype[shape]):
attrs.func_kind = 'public'
all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op))

# Build
ir_module = script_module.ir_module()
ir_module.target = 'cuda'
ir_module.include_dirs.extend(get_nccl_include_dirs())
ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
ir_module.include_headers.append(["nccl.h"])
ir_module.linking_libs.append(":" + nccl_library_filename())
out_dir = f'./.cache/all_reduce_{rank}'

build_ir_module(ir_module, out_dir, target='cuda')
compiled_module = load_compiled_module(out_dir)

compiled_module(send, recv)
s = hidet.cuda.current_stream()
s.synchronize()
print(rank, recv)
x = hidet.randn([1, 3], device=device)
w = hidet.randn([3, 2], device=device)

# test runtime distributed op
hidet.distributed.all_reduce(w, 'avg')
print(w)

# Create Computation Graph
x_symb = hidet.symbol_like(x)
w_symb = hidet.symbol_like(w)
y_local = hidet.ops.relu(x_symb @ w_symb)
y_sync = hidet.ops.all_reduce(y_local, args.reduce_op)
graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb])
opt_graph = hidet.graph.optimize(graph)
compiled = opt_graph.build()
y_local, y_sync = compiled(x, w)

hidet.cuda.current_stream().synchronize()
print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='')
atexit._run_exitfuncs()

if os.path.exists('tmp'):
os.remove('tmp')

world_size = args.n_gpus

# Barrier to ensure unique id is created
barrier = multiprocessing.Barrier(world_size)

# Create a unique id object in shared memory
shared_id = multiprocessing.Value(NcclUniqueId, lock=False)

processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)]
processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)]

for p in processes:
p.start()
Expand Down
1 change: 1 addition & 0 deletions python/hidet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import drivers
from . import logging
from . import cuda
from . import distributed

from .version import __version__

Expand Down
15 changes: 13 additions & 2 deletions python/hidet/cuda/nccl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl
from .ffi import nccl_version, nccl_library_filename
from .ffi import nccl_available, nccl_version, nccl_library_filename
from .comm import (
create_comm,
NcclUniqueId,
NcclDataType,
NcclRedOp,
comms_to_array,
create_unique_id,
dtype_to_nccl,
NcclCommunicator,
str_to_nccl_op,
NCCL_SPLIT_NOCOLOR,
)
45 changes: 38 additions & 7 deletions python/hidet/cuda/nccl/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntEnum
from typing import List
from typing import List, Optional
import struct

from hidet.ffi.utils import Array
from hidet.ir.type import void_p, DataType
from .ffi import nccl_runtime_api, NcclUniqueId
from hidet.cuda import Stream, current_stream
from .ffi import nccl_available, NcclUniqueId

NCCL_SPLIT_NOCOLOR = -1

if nccl_available():
from .ffi import nccl_runtime_api


class NcclDataType(IntEnum):
Expand Down Expand Up @@ -44,13 +50,20 @@ class NcclRedOp(IntEnum):
avg = 4


def str_to_nccl_op(name: str) -> NcclRedOp:
    """
    Translate the textual name of a reduction into the NcclRedOp attribute of the same name.

    Only 'sum', 'prod', 'max', 'min' and 'avg' are accepted; any other value raises
    RuntimeError before NcclRedOp is consulted.
    """
    accepted = ('sum', 'prod', 'max', 'min', 'avg')
    if name in accepted:
        return getattr(NcclRedOp, name)
    raise RuntimeError(f"'{name}' is not a supported reduce op")


class NcclCommunicator:
def __init__(self, handle: int):
    """
    Wrap an existing communicator handle.

    Not meant to be called by users directly: a communicator is obtained either
    1) from a unique_id and a rank, or 2) by splitting another communicator.

    Raises RuntimeError when nccl_available() reports that NCCL is missing.
    """
    nccl_ready = nccl_available()
    if not nccl_ready:
        # The handle is deliberately not stored when NCCL is unavailable.
        raise RuntimeError("NCCL is not available")
    self._handle = handle

def __del__(self):
Expand All @@ -60,11 +73,25 @@ def __del__(self):
def handle(self):
return self._handle

def split(self):
raise NotImplementedError()
def split(self, key, color):
    """
    Split this communicator through nccl_runtime_api.comm_split.

    Mind the argument order: this method takes (key, color) whereas the runtime
    call receives (color, key).

    Returns None when color equals NCCL_SPLIT_NOCOLOR; otherwise a new
    NcclCommunicator wrapping the handle returned by the runtime call.
    """
    # The runtime call is issued even for NCCL_SPLIT_NOCOLOR; only the wrapping is skipped.
    child_handle = nccl_runtime_api.comm_split(self._handle, color, key)
    if color != NCCL_SPLIT_NOCOLOR:
        return NcclCommunicator(child_handle)
    return None

def all_reduce(
    self, sendbuff: int, recvbuff: int, count: int, datatype: DataType, op: str, s: Optional[Stream] = None
):
    """
    Issue an all-reduce on this communicator via nccl_runtime_api.all_reduce.

    sendbuff, recvbuff and count are forwarded unchanged. datatype is converted with
    dtype_to_nccl and op (a reduction name such as 'sum') with str_to_nccl_op, both
    passed on as plain ints. When s is None the result of current_stream() is used.
    """
    stream = current_stream() if s is None else s
    nccl_dtype = int(dtype_to_nccl(datatype))
    nccl_op = int(str_to_nccl_op(op))
    nccl_runtime_api.all_reduce(sendbuff, recvbuff, count, nccl_dtype, nccl_op, self._handle, stream)


def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator:
    """
    Build an NcclCommunicator from a unique id and a rank.

    The handle comes from nccl_runtime_api.comm_init_rank(nranks, unique_id, rank).
    Raises RuntimeError when nccl_available() is false.
    """
    if nccl_available():
        return NcclCommunicator(nccl_runtime_api.comm_init_rank(nranks, unique_id, rank))
    raise RuntimeError("NCCL is not available")

Expand All @@ -76,8 +103,12 @@ def comms_to_array(comms: List[NcclCommunicator]) -> Array:
return array


def init_unique_id(unqie_id: NcclUniqueId) -> None:
nccl_runtime_api.get_unique_id(unqie_id)
def create_unique_id() -> NcclUniqueId:
    """
    Allocate a NcclUniqueId, have nccl_runtime_api.get_unique_id fill it in place, and return it.

    Raises RuntimeError when nccl_available() is false.
    """
    if nccl_available():
        fresh_id = NcclUniqueId()
        nccl_runtime_api.get_unique_id(fresh_id)
        return fresh_id
    raise RuntimeError("NCCL is not available")


def dtype_to_nccl(dtype: DataType) -> NcclDataType:
Expand Down
125 changes: 80 additions & 45 deletions python/hidet/cuda/nccl/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

from typing import Optional
import ctypes
from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER
from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER, c_uint64
import glob
import os

from hidet.ffi.ffi import get_func
from hidet.cuda import Stream
from .libinfo import get_nccl_library_search_dirs

_LIB_NCCL: Optional[ctypes.CDLL] = None
Expand Down Expand Up @@ -49,8 +50,6 @@ def load_nccl_library():
_LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0])
nccl_library_path = lib_nccl_paths[0]
break
if _LIB_NCCL is None:
raise OSError('Can not find nccl library in the following directory: \n' + '\n'.join(library_dirs))


load_nccl_library()
Expand All @@ -60,49 +59,85 @@ def nccl_library_filename():
return os.path.basename(nccl_library_path)


if not nccl_available():
raise RuntimeError("NCCL Library not found.")
if nccl_available():


class NCCLRuntimeAPI:
"""
Runtime APIs regarding NCCL
TODO: Exception handling
"""

_get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL)
_get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL)
_comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL)
_comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL)

_comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)
_comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)

@staticmethod
def get_version() -> int:
version = c_int(0)
NCCLRuntimeAPI._get_version(pointer(version))
return version.value

@staticmethod
def get_unique_id(comm_id: NcclUniqueId) -> None:
class NCCLRuntimeAPI:
"""
In-place initialization of the NcclUniqueId object
Runtime APIs regarding NCCL
TODO: Exception handling
"""
ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
assert ret == 0, ret

@staticmethod
def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
comm = c_void_p()
ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
assert ret == 0, ret
return comm.value

@staticmethod
def comm_destroy(comm_handle) -> None:
ret = NCCLRuntimeAPI._comm_destroy(comm_handle)
assert ret == 0


nccl_runtime_api = NCCLRuntimeAPI()
_get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL)
_get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL)
_comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL)
_comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL)

_comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)
_comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)

_all_reduce = get_func(
'ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL
)
_broadcast = get_func(
'ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL
)
_reduce = get_func(
'ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL
)
_all_gather = get_func(
'ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL
)
_reduce_scatter = get_func(
'ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL
)

# Early versions of NCCL do not have split
try:
_comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL)
except ValueError:
_comm_split = None

@staticmethod
def get_version() -> int:
    """
    Call ncclGetVersion and return the integer it writes into the output slot.

    The status code returned by the call is not inspected.
    """
    out = c_int(0)
    NCCLRuntimeAPI._get_version(pointer(out))
    return out.value

@staticmethod
def get_unique_id(comm_id: NcclUniqueId) -> None:
    """
    In-place initialization of the NcclUniqueId object.

    Raises RuntimeError if ncclGetUniqueId returns a non-zero status.
    """
    ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let a failed call go unnoticed.
    if ret != 0:
        raise RuntimeError(f"ncclGetUniqueId failed with error code {ret}")

@staticmethod
def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
    """
    Call ncclCommInitRank and return the resulting communicator handle.

    ndev, comm_id and rank are forwarded to the native call; the handle is read back
    from the output pointer as an integer.

    Raises RuntimeError if ncclCommInitRank returns a non-zero status.
    """
    comm = c_void_p()
    ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would hand an uninitialized handle back to the caller.
    if ret != 0:
        raise RuntimeError(f"ncclCommInitRank failed with error code {ret}")
    return comm.value

@staticmethod
def comm_destroy(comm_handle) -> None:
    """
    Call ncclCommDestroy on the given communicator handle.

    Raises RuntimeError if ncclCommDestroy returns a non-zero status.
    """
    ret = NCCLRuntimeAPI._comm_destroy(comm_handle)
    # Explicit raise instead of a bare `assert`: the check survives `python -O`
    # and the error code is reported.
    if ret != 0:
        raise RuntimeError(f"ncclCommDestroy failed with error code {ret}")

@staticmethod
def comm_split(comm_handle: int, color: int, key: int) -> int:
    """
    Call ncclCommSplit on comm_handle and return the new communicator handle.

    color and key are forwarded as given; the config argument of the native call
    is passed as None.

    Raises RuntimeError if the loaded NCCL library does not export ncclCommSplit,
    or if the call returns a non-zero status.
    """
    if NCCLRuntimeAPI._comm_split is None:
        raise RuntimeError("split is not supported on this version of NCCL. Please install a newer version.")
    comm = c_void_p()
    ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None)
    # Explicit raise instead of a bare `assert`: the check survives `python -O`
    # and matches the RuntimeError raised above.
    if ret != 0:
        raise RuntimeError(f"ncclCommSplit failed with error code {ret}")
    return comm.value

# TODO: Currently only support all_reduce
@staticmethod
def all_reduce(
    sendbuff: int, recvbuff: int, count: int, datatype: int, op: int, comm_handle: int, s: Stream
) -> None:
    """
    Call ncclAllReduce with the given buffers, count, datatype, op and communicator.

    The stream s is converted with int(s) and passed as a void pointer.

    Raises RuntimeError if ncclAllReduce returns a non-zero status.
    """
    ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, c_void_p(int(s)))
    # Explicit raise instead of a bare `assert`: the check survives `python -O`
    # and the error code is reported.
    if ret != 0:
        raise RuntimeError(f"ncclAllReduce failed with error code {ret}")

nccl_runtime_api = NCCLRuntimeAPI()
Loading

0 comments on commit 7c52c9d

Please sign in to comment.