diff --git a/examples/distributed/test.py b/examples/distributed/test.py index bad43ef23..7fc1acb81 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -1,93 +1,65 @@ """ Testing script for distributed components for hidet To debug, set the environment variable NCCL_DEBUG=INFO + +To install nccl, run + + pip install nvidia-nccl-cu11==2.18.3 + +Or + + pip install nvidia-nccl-cu12==2.18.3 """ import hidet import multiprocessing from multiprocessing import Process import numpy import argparse +import atexit +import os +import hidet import hidet.cuda.nccl -from hidet.cuda import nccl -from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename -from hidet.ffi import runtime_api -from hidet.lang import attrs -from hidet.ir.primitives.cuda.nccl import all_reduce -from hidet.ir.type import data_type -from hidet.utils import prod -from hidet.drivers import build_ir_module -from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs -from hidet.runtime import load_compiled_module - -print("NCCL version:", nccl.nccl_version()) parser = argparse.ArgumentParser() parser.add_argument("n_gpus", type=int) parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg']) args = parser.parse_args() -def run(world_size, rank, shared_id, barrier): +def run(world_size, rank): numpy.random.seed(rank) - # Initialize unique id - if rank == 0: - nccl.init_unique_id(shared_id) - - barrier.wait() hidet.cuda.set_device(rank) + hidet.distributed.init_process_group(init_method='file://tmp', world_size=world_size, rank=rank) + hidet.distributed.set_nccl_comms() - print('initialize', rank) - # Create NcclCommunicator and set the cuda context - # this part should be moved into CompiledGraph in the future - comm = nccl.create_comm(world_size, shared_id, rank) - comms_array = nccl.comms_to_array([comm]) - runtime_api.set_nccl_comms(comms_array) - - # Initialize send and receive buffer device = f"cuda:{rank}" - send = hidet.randn([2, 2], device=device) - recv = hidet.empty([2, 2], device=device) - - print(rank, send) - - dtype = data_type('float32') - shape = [2, 2] - nbytes = dtype.nbytes * prod(shape) - - # Define IRModule - with hidet.script_module() as script_module: - @hidet.script - def launch(send: dtype[shape], recv: dtype[shape]): - attrs.func_kind = 'public' - all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op)) - - # Build - ir_module = script_module.ir_module() - ir_module.target = 'cuda' - ir_module.include_dirs.extend(get_nccl_include_dirs()) - ir_module.linking_dirs.extend(get_nccl_library_search_dirs()) - ir_module.include_headers.append(["nccl.h"]) - ir_module.linking_libs.append(":" + nccl_library_filename()) - out_dir = f'./.cache/all_reduce_{rank}' - - build_ir_module(ir_module, out_dir, target='cuda') - compiled_module = load_compiled_module(out_dir) - - compiled_module(send, recv) - s = hidet.cuda.current_stream() - s.synchronize() - print(rank, recv) + x = hidet.randn([1, 3], device=device) + w = hidet.randn([3, 2], device=device) + + # test runtime distributed op + hidet.distributed.all_reduce(w, 'avg') + print(w) + + # Create Computation Graph + x_symb = hidet.symbol_like(x) + w_symb = hidet.symbol_like(w) + y_local = hidet.ops.relu(x_symb @ w_symb) + y_sync = hidet.ops.all_reduce(y_local, args.reduce_op) + graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) + opt_graph = hidet.graph.optimize(graph) + compiled = opt_graph.build() + y_local, y_sync = compiled(x, 
w) + + hidet.cuda.current_stream().synchronize() + print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') + atexit._run_exitfuncs() + +if os.path.exists('tmp'): + os.remove('tmp') world_size = args.n_gpus - -# Barrier to ensure unique id is created -barrier = multiprocessing.Barrier(world_size) - -# Create a unique id object in shared memory -shared_id = multiprocessing.Value(NcclUniqueId, lock=False) - -processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)] +processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)] for p in processes: p.start() diff --git a/python/hidet/__init__.py b/python/hidet/__init__.py index 9c3206e25..94efae1f2 100644 --- a/python/hidet/__init__.py +++ b/python/hidet/__init__.py @@ -21,6 +21,7 @@ from . import drivers from . import logging from . import cuda +from . import distributed from .version import __version__ diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index 6ff476d71..ec8d1dfed 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -9,5 +9,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl -from .ffi import nccl_version, nccl_library_filename +from .ffi import nccl_available, nccl_version, nccl_library_filename +from .comm import ( + create_comm, + NcclUniqueId, + NcclDataType, + NcclRedOp, + comms_to_array, + create_unique_id, + dtype_to_nccl, + NcclCommunicator, + str_to_nccl_op, + NCCL_SPLIT_NOCOLOR, +) diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 241ffc18f..416d0f0cf 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -10,12 +10,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import IntEnum -from typing import List +from typing import List, Optional import struct from hidet.ffi.utils import Array from hidet.ir.type import void_p, DataType -from .ffi import nccl_runtime_api, NcclUniqueId +from hidet.cuda import Stream, current_stream +from .ffi import nccl_available, NcclUniqueId + +NCCL_SPLIT_NOCOLOR = -1 + +if nccl_available(): + from .ffi import nccl_runtime_api class NcclDataType(IntEnum): @@ -44,13 +50,20 @@ class NcclRedOp(IntEnum): avg = 4 +def str_to_nccl_op(name: str) -> NcclRedOp: + if name not in ('sum', 'prod', 'max', 'min', 'avg'): + raise RuntimeError(f"'{name}' is not a supported reduce op") + return getattr(NcclRedOp, name) + + class NcclCommunicator: def __init__(self, handle: int): """ Users should not call this constructor directly. Because there are two ways of creating a new communicator: 1) using unique_id and rank ; 2) using split. 
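+
+        A minimal sketch of the two paths (``world_size`` and ``rank`` are placeholders here;
+        sharing the unique id across processes, e.g. through a store, is up to the caller):
+
+            unique_id = create_unique_id()                     # on rank 0, then shared
+            comm = create_comm(world_size, unique_id, rank)    # 1) unique_id and rank
+            sub_comm = comm.split(key=rank, color=rank % 2)    # 2) split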
""" - + if not nccl_available(): + raise RuntimeError("NCCL is not available") self._handle = handle def __del__(self): @@ -60,11 +73,25 @@ def __del__(self): def handle(self): return self._handle - def split(self): - raise NotImplementedError() + def split(self, key, color): + new_handle = nccl_runtime_api.comm_split(self._handle, color, key) + if color == NCCL_SPLIT_NOCOLOR: + return None + return NcclCommunicator(new_handle) + + def all_reduce( + self, sendbuff: int, recvbuff: int, count: int, datatype: DataType, op: str, s: Optional[Stream] = None + ): + if s is None: + s = current_stream() + nccl_runtime_api.all_reduce( + sendbuff, recvbuff, count, int(dtype_to_nccl(datatype)), int(str_to_nccl_op(op)), self._handle, s + ) def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator: + if not nccl_available(): + raise RuntimeError("NCCL is not available") handle = nccl_runtime_api.comm_init_rank(nranks, unique_id, rank) return NcclCommunicator(handle) @@ -76,8 +103,12 @@ def comms_to_array(comms: List[NcclCommunicator]) -> Array: return array -def init_unique_id(unqie_id: NcclUniqueId) -> None: - nccl_runtime_api.get_unique_id(unqie_id) +def create_unique_id() -> NcclUniqueId: + if not nccl_available(): + raise RuntimeError("NCCL is not available") + unique_id = NcclUniqueId() + nccl_runtime_api.get_unique_id(unique_id) + return unique_id def dtype_to_nccl(dtype: DataType) -> NcclDataType: diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 66a7dbc74..f415cb006 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -12,11 +12,12 @@ from typing import Optional import ctypes -from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER +from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER, c_uint64 import glob import os from hidet.ffi.ffi import get_func +from hidet.cuda import Stream from .libinfo import get_nccl_library_search_dirs _LIB_NCCL: Optional[ctypes.CDLL] = None @@ -49,8 +50,6 @@ def load_nccl_library(): _LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0]) nccl_library_path = lib_nccl_paths[0] break - if _LIB_NCCL is None: - raise OSError('Can not find nccl library in the following directory: \n' + '\n'.join(library_dirs)) load_nccl_library() @@ -60,49 +59,85 @@ def nccl_library_filename(): return os.path.basename(nccl_library_path) -if not nccl_available(): - raise RuntimeError("NCCL Library not found.") +if nccl_available(): - -class NCCLRuntimeAPI: - """ - Runtime APIs regarding NCCL - TODO: Exception handling - """ - - _get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL) - _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL) - _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL) - _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL) - - _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) - _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) - - @staticmethod - def get_version() -> int: - version = c_int(0) - NCCLRuntimeAPI._get_version(pointer(version)) - return version.value - - @staticmethod - def get_unique_id(comm_id: NcclUniqueId) -> None: + class NCCLRuntimeAPI: """ - In-place initialization of the NcclUniqueId object + Runtime APIs regarding NCCL + TODO: Exception handling """ - ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id)) - assert ret == 
0, ret - - @staticmethod - def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int: - comm = c_void_p() - ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank) - assert ret == 0, ret - return comm.value - - @staticmethod - def comm_destroy(comm_handle) -> None: - ret = NCCLRuntimeAPI._comm_destroy(comm_handle) - assert ret == 0 - -nccl_runtime_api = NCCLRuntimeAPI() + _get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL) + _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL) + _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL) + _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL) + + _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) + _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) + + _all_reduce = get_func( + 'ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL + ) + _broadcast = get_func( + 'ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL + ) + _reduce = get_func( + 'ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL + ) + _all_gather = get_func( + 'ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL + ) + _reduce_scatter = get_func( + 'ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL + ) + + # Early versions of NCCL do not have split + try: + _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) + except ValueError: + _comm_split = None + + @staticmethod + def get_version() -> int: + version = c_int(0) + NCCLRuntimeAPI._get_version(pointer(version)) + return version.value + + @staticmethod + def get_unique_id(comm_id: NcclUniqueId) -> None: + """ + In-place initialization of the NcclUniqueId object + """ + ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id)) + assert ret == 0, ret + + @staticmethod + def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int: + comm = c_void_p() + ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank) + assert ret == 0, ret + return comm.value + + @staticmethod + def comm_destroy(comm_handle) -> None: + ret = NCCLRuntimeAPI._comm_destroy(comm_handle) + assert ret == 0 + + @staticmethod + def comm_split(comm_handle: int, color: int, key: int) -> int: + if NCCLRuntimeAPI._comm_split is None: + raise RuntimeError("split is not supported on this version of NCCL. 
Please install a newer version.")
+            comm = c_void_p()
+            ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None)
+            assert ret == 0
+            return comm.value
+
+        # TODO: currently only all_reduce is supported
+        @staticmethod
+        def all_reduce(
+            sendbuff: int, recvbuff: int, count: int, datatype: int, op: int, comm_handle: int, s: Stream
+        ) -> None:
+            ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, c_void_p(int(s)))
+            assert ret == 0
+
+    nccl_runtime_api = NCCLRuntimeAPI()
diff --git a/python/hidet/distributed/__init__.py b/python/hidet/distributed/__init__.py
new file mode 100644
index 000000000..a5b06eb29
--- /dev/null
+++ b/python/hidet/distributed/__init__.py
@@ -0,0 +1,15 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .distributed import init_process_group, all_reduce
+from .group import set_nccl_comms
+from .store import FileStore
diff --git a/python/hidet/distributed/distributed.py b/python/hidet/distributed/distributed.py
new file mode 100644
index 000000000..b60e433dc
--- /dev/null
+++ b/python/hidet/distributed/distributed.py
@@ -0,0 +1,105 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+from datetime import timedelta
+
+from hidet.graph import Tensor
+from hidet.cuda.nccl import nccl_available
+from .store import Store, FileStore
+from .group import create_nccl_group, ProcessGroup
+
+
+DEFAULT_TIMEOUT = timedelta(seconds=1800)
+
+DEFAULT_GROUP = None
+
+
+def init_process_group(
+    backend: str = 'nccl',
+    init_method: Optional[str] = None,
+    store: Optional[Store] = None,
+    timeout: timedelta = DEFAULT_TIMEOUT,
+    world_size: int = -1,
+    rank: int = -1,
+):
+    """
+    We use the same API as PyTorch.
+    Currently we only support FileStore. There are two ways to initialize via FileStore:
+    1. Manually create a FileStore object and pass it as ``store``;
+    2. Specify ``init_method`` with ``file://path-to-file``.
+    For now, ``world_size`` and ``rank`` still need to be specified manually.
+    """
+    global DEFAULT_GROUP
+
+    if world_size <= 0 or rank < 0:
+        raise RuntimeError("'world_size' and 'rank' must be specified.")
+
+    if rank >= world_size:
+        raise RuntimeError("'rank' must be smaller than 'world_size'")
+
+    if store is None:
+        if init_method is None:
+            raise RuntimeError("One of 'init_method' and 'store' must be specified.")
+        else:
+            if not init_method.startswith('file://'):
+                raise RuntimeError(
+                    "Currently only FileStore is supported. "
+                    "Please specify the path to the file store with 'file://path-to-file'"
+                )
+            path_to_file = init_method[len('file://') :]
+            store = FileStore(path_to_file)
+    else:
+        if init_method is not None:
+            raise RuntimeError("'init_method' and 'store' are mutually exclusive.")
+
+    store.set_timeout(timeout)
+    if backend == 'nccl':
+        if not is_nccl_available():
+            raise RuntimeError("NCCL is not found.")
+        DEFAULT_GROUP = create_nccl_group(store, world_size, rank)
+    else:
+        raise ValueError(f"Backend {backend} is not supported.")
+
+
+def is_initialized():
+    return DEFAULT_GROUP is not None
+
+
+def is_nccl_available():
+    return nccl_available()
+
+
+def broadcast():
+    raise NotImplementedError()
+
+
+def all_reduce(tensor: Tensor, op: str, group: Optional[ProcessGroup] = None):
+    if group is None:
+        group = DEFAULT_GROUP
+    group.all_reduce(tensor, op)
+
+
+def reduce():
+    raise NotImplementedError()
+
+
+def all_gather_into_tensor():
+    raise NotImplementedError()
+
+
+def scatter():
+    raise NotImplementedError()
+
+
+def reduce_scatter_tensor():
+    raise NotImplementedError()
diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py
new file mode 100644
index 000000000..d60de97dc
--- /dev/null
+++ b/python/hidet/distributed/group.py
@@ -0,0 +1,104 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# pylint: disable=W0223 + +from typing import Optional, List + +from hidet.graph import Tensor +from hidet.cuda.nccl import create_unique_id, NcclUniqueId, create_comm, NcclCommunicator, comms_to_array +from .store import Store + + +class ProcessGroup: + def backend(self) -> str: + raise NotImplementedError() + + def rank(self) -> int: + raise NotImplementedError() + + def size(self) -> int: + raise NotImplementedError() + + def broadcast(self, tensor: Tensor, src: int): + raise NotImplementedError() + + def all_reduce(self, tensor: Tensor, op: str): + raise NotImplementedError() + + def reduce(self, tensor: Tensor, dst: int, op: str): + raise NotImplementedError() + + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor): + raise NotImplementedError() + + def all_gather_into_tensor(self, output_tensor: Tensor, input_tensor: Tensor): + raise NotImplementedError() + + def gather(self, tensor: Tensor, gather_list: Optional[List[Tensor]] = None, dst: int = 0): + raise NotImplementedError() + + def scatter(self, tensor: Tensor, scattler_list: Optional[List[Tensor]] = None): + raise NotImplementedError() + + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str): + raise NotImplementedError() + + def reduce_scatter_tensor(self, output: Tensor, input: Tensor, op: str): + raise NotImplementedError() + + def barrier(self): + raise NotImplementedError() + + +NCCL_COMMS: List[NcclCommunicator] = [] +_NCCL_ARRAY: 'Array' = None + + +class NCCLProcessGroup(ProcessGroup): + def __init__(self, comm: NcclCommunicator, world_size: int, rank: int): + self._comm: NcclCommunicator = comm + self._world_size: int = world_size + self._rank: int = rank + NCCL_COMMS.append(comm) + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._world_size + + def all_reduce(self, tensor: Tensor, op: str): + assert not tensor.is_symbolic() + assert tensor.device.is_cuda() + addr = tensor.storage.addr + self._comm.all_reduce(addr, addr, tensor.nbytes, tensor.dtype, op) + + +def create_nccl_group(store: Store, world_size: int, rank: int): + if rank == 0: + unique_id = create_unique_id() + store.set('unique_id', unique_id.internal) + else: + _id = store.get('unique_id') + unique_id = NcclUniqueId() + unique_id.internal[:] = _id[:] + comm = create_comm(world_size, unique_id, rank) + return NCCLProcessGroup(comm, world_size, rank) + + +def set_nccl_comms(): + global _NCCL_ARRAY + from hidet.ffi.runtime_api import runtime_api + + _NCCL_ARRAY = comms_to_array(NCCL_COMMS) + runtime_api.set_nccl_comms(_NCCL_ARRAY) diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py new file mode 100644 index 000000000..5e53f2c2e --- /dev/null +++ b/python/hidet/distributed/store.py @@ -0,0 +1,215 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
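+
+"""
+Key-value store abstractions used to bootstrap process groups: ``Store`` defines the
+abstract interface, and ``FileStore`` implements it on top of a shared file guarded by a
+file lock (see the class docstring below for the on-disk format).
+"""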
+
+from typing import List, Optional, Dict
+from datetime import timedelta, datetime
+import time
+import struct
+import os
+import atexit
+import filelock
+
+
+class Store:
+    def set(self, key: str, value: bytes) -> None:
+        raise NotImplementedError()
+
+    def get(self, key: str) -> bytes:
+        raise NotImplementedError()
+
+    def add(self, key: str, amount: int) -> int:
+        raise NotImplementedError()
+
+    def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes:
+        raise NotImplementedError()
+
+    def wait(self, keys: List[str], timeout: Optional[timedelta] = None) -> None:
+        raise NotImplementedError()
+
+    def num_keys(self) -> int:
+        raise NotImplementedError()
+
+    def delete_key(self, key: str) -> bool:
+        raise NotImplementedError()
+
+    def set_timeout(self, timeout: timedelta):
+        raise NotImplementedError()
+
+
+class FileStore(Store):
+    """
+    A shared KV-store based on the local filesystem.
+
+    It creates a binary file (specified by the filename argument) and a lock file.
+    Each time a new entry (key, value) is inserted, it is appended to the end of the file.
+    Only the newest entry is effective among all entries with the same key, so scanning the
+    file from the beginning gives the up-to-date state of the KV-store.
+
+    All keys passed to public methods are given the prefix '+' (REGULAR_PREFIX) to
+    distinguish them from internal keys used by the store itself. For example, we keep an
+    internal entry 'cnt' that tracks how many clients are currently using this store.
+
+    Keys are converted from Python strings to bytes automatically, while values are not,
+    since values can be arbitrary byte arrays that might not be decodable. Please do that
+    conversion manually if required.
+
+    We use a 4-byte integer to record the length of each (encoded) key and value, so do not
+    insert more than 2^31 - 1 bytes for a single entry.
+
+    Deletion of an entry is done by appending a new entry whose key carries the prefix '-'
+    (DELETE_PREFIX); it overrides earlier insertions of that key when the file is scanned.
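+
+    A minimal usage sketch (the path is illustrative; every process must point at the same file):
+
+        store = FileStore('/tmp/hidet_store')
+        store.set('unique_id', b'0123')     # one process publishes a value
+        data = store.get('unique_id')       # another process blocks until the key exists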
+ """ + + REGULAR_PREFIX = '+' + DELETE_PREFIX = '-' + + def __init__(self, filename: str, world_size: Optional[int] = -1): + self._filename: str = filename + self._lock_filename: str = filename + '.lock' + self._world_size: int = world_size + + self._lock: filelock.FileLock = filelock.FileLock(self._lock_filename) + self._cache: Dict[str, bytes] = {} + self._timeout: timedelta = None + + num_peers = self._add('cnt', 1) + if 0 <= world_size < num_peers: + raise RuntimeError("Warning: more peers than world size.") + + # We cannot operate files in __del__, and we don't want to call close explicitly + # So we register a atexit function doing cleanup when python interpreter exits + @atexit.register + def cleanup(): + with self._lock: + if os.path.exists(self._filename): + rest = self._add('cnt', -1) + if rest == 0: + os.remove(self._filename) + + def _write(self, f, content): + f.write(struct.pack('i', len(content))) + f.write(content) + + def _read(self, f): + len_str = f.read(4) + if len_str == b'': + return None + l = struct.unpack('i', len_str)[0] + return f.read(l) + + def _file_size(self, f): + origin_pos = f.tell() + f.seek(0, 2) # 2 means the file's end + size = f.tell() + f.seek(origin_pos, 0) + return size + + def _update(self, f): + self._cache = {} + f.seek(0) + while True: + k = self._read(f) + if k is None: + return + v = self._read(f) + k = k.decode() + if k.startswith(self.DELETE_PREFIX): + k = k[len(self.DELETE_PREFIX) :] + del self._cache[k] + else: + self._cache[k] = v + + def _add(self, key: str, amount: int) -> int: + with self._lock: + with open(self._filename, "ab+") as f: + self._update(f) + value = int(self._cache.get(key, '0')) + amount + with open(self._filename, "ab+") as f: + self._write(f, bytes(key, encoding='utf-8')) + self._write(f, bytes(str(value), encoding='utf-8')) + return value + + def _check(self, keys: List[str]): + with self._lock: + with open(self._filename, "ab+") as f: + self._update(f) + return all((key in self._cache for key in keys)) + + def _set(self, key: str, value: bytes): + with self._lock: + with open(self._filename, "ab+") as f: + self._write(f, bytes(key, encoding='utf-8')) + self._write(f, value) + + def set(self, key: str, value: bytes) -> None: + self._set(self.REGULAR_PREFIX + key, value) + + def get(self, key: str) -> bytes: + last_file_size = None + key = self.REGULAR_PREFIX + key + start_t = datetime.now() + while True: + self._lock.acquire() + with open(self._filename, "ab+") as f: + file_size = self._file_size(f) + if key not in self._cache and file_size == last_file_size: + # No new entries + self._lock.release() + if self._timeout is not None and datetime.now() - start_t > self._timeout: + raise TimeoutError() + time.sleep(0.01) + continue + last_file_size = file_size + self._update(f) + self._lock.release() + value = self._cache.get(key) + if value is not None: + return value + + def add(self, key: str, amount: int) -> int: + return self._add(self.REGULAR_PREFIX + key, amount) + + def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: + key = self.REGULAR_PREFIX + key + with self._lock: + with open(self._filename, "ab+") as f: + self._update(f) + has_key = key in self._cache + if (not has_key and expected == b'') or (has_key and self._cache[key] == expected): + f.seek(0, 2) + self._write(f, bytes(key, encoding='utf-8')) + self._write(f, desired) + return desired + elif not has_key: + return expected + return self._cache[key] + + def wait(self, keys: List[str], timeout: Optional[timedelta] = None) -> None: + 
timeout = self._timeout if self._timeout is not None else timeout + start_t = datetime.now() + keys = [self.REGULAR_PREFIX + key for key in keys] + while not self._check(keys): + if timeout is not None and datetime.now() - start_t > timeout: + raise TimeoutError() + time.sleep(0.01) + + def num_keys(self): + with self._lock(): + with open(self._filename, "rb") as f: + self._update(f) + return len(self._cache) + + def delete_key(self, key: str): + self._set(self.DELETE_PREFIX + self.REGULAR_PREFIX + key, b'') + + def set_timeout(self, timeout: timedelta): + self._timeout = timeout diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py index 2286d99d0..996d02bff 100644 --- a/python/hidet/graph/flow_graph.py +++ b/python/hidet/graph/flow_graph.py @@ -193,6 +193,7 @@ def forward(self, inputs: List[Tensor]) -> List[Tensor]: output: List[Tensor] The output tensors of the computation graph. """ + from hidet.ffi import runtime_api inputs: List[Tensor] = list(inputs) diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 599b403c9..6b2173b3b 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -52,5 +52,6 @@ from .fusion import fused_operator from .transfer import transfer from .special import barrier +from .distributed import all_reduce from . import utils diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py new file mode 100644 index 000000000..3cc5a964b --- /dev/null +++ b/python/hidet/graph/ops/distributed.py @@ -0,0 +1,88 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union, Tuple + +from hidet.ir.type import DataType +from hidet.ir.expr import Expr +from hidet.ir.module import IRModule +from hidet.ir.task import Target +from hidet.utils import prod +from hidet.cuda.nccl import str_to_nccl_op +from .utils import Task, TensorNode, Operator, Tensor, compute, input_like + + +class AllReduceTask(Task): + def __init__(self, x: TensorNode, op: str, comm_id: int = 0): + y = compute('out', x.shape, lambda *indices: x[indices]) + self.comm_id = comm_id + self.op = op + + super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={'comm_id': comm_id, 'op': op}) + + def __str__(self): + return "all_reduce" + + def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: + import hidet + from hidet.ir.primitives.cuda.nccl import all_reduce as _all_reduce + from hidet.lang import attrs + + dtype: DataType = self.inputs[0].type.dtype + shape: Tuple[Expr, ...] 
= self.inputs[0].shape + nbytes = dtype.nbytes * prod(shape) + + with hidet.script_module() as script_module: + + @hidet.script + def launch(x: dtype[shape], y: dtype[shape]): + attrs.func_kind = 'public' + _all_reduce(x, y, nbytes, dtype, str_to_nccl_op(self.op), self.comm_id) + + return [script_module.ir_module()] + + +class AllReduceOp(Operator): + def __init__(self, x: Tensor, op: str, comm_id: int): + super().__init__( + inputs=[x], attributes={'op': op, 'comm_id': comm_id}, task=AllReduceTask(input_like(x, 'x'), op, comm_id) + ) + + +def all_reduce(x: Tensor, op: str, comm_id: int = 0) -> Tensor: + if x.device.kind != 'cuda': + raise RuntimeError("NCCL only supports CUDA tensors") + return AllReduceOp(x, op, comm_id).outputs[0] + + +def broadcast(x: Tensor, root: int, comm_id: int = 0) -> Tensor: + raise NotImplementedError() + + +def reduce(x: Tensor, root: int, op: str, comm_id: int = 0) -> Tensor: + raise NotImplementedError() + + +def all_gather(x: Tensor, comm_id: int = 0) -> Tensor: + raise NotImplementedError() + + +def reduce_scatter(x: Tensor, op: str, comm_id: int = 0) -> Tensor: + raise NotImplementedError() + + +def send(x: Tensor, peer: int, comm_id: int = 0) -> None: + raise NotImplementedError() + + +# Recv is a little bit tricky since we need to pass the metadata of the recv buffer +def recv(peer: int, comm_id: int = 0) -> Tensor: + raise NotImplementedError() diff --git a/python/hidet/ir/primitives/cuda/nccl.py b/python/hidet/ir/primitives/cuda/nccl.py index 96b2b2a2d..669d72322 100644 --- a/python/hidet/ir/primitives/cuda/nccl.py +++ b/python/hidet/ir/primitives/cuda/nccl.py @@ -18,7 +18,7 @@ from hidet.cuda.nccl import NcclRedOp, dtype_to_nccl -def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp): +def all_reduce(sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp, comm_id: int): from hidet.ir.primitives.runtime import get_cuda_stream, get_nccl_comm comm = get_nccl_comm(comm_id) diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index cce55954f..40406eb19 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -30,7 +30,6 @@ from hidet.ffi import runtime_api from hidet.utils import prod - ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None] diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index f91be59ef..6d21b0f6f 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -36,6 +36,7 @@ from .propagate_launch_bound import propagate_launch_bound_pass from .check_launch_configuration import check_launch_configuration_pass from .lower_special_cast import lower_special_cast_pass +from .annotate_header_and_libs import annotate_header_and_libs_pass def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule: @@ -80,5 +81,6 @@ def lower(ir_module: IRModule) -> IRModule: rule_based_simplify_pass(), inline_let_stmt_pass(), simplify_stmt_pass(), + annotate_header_and_libs_pass(), ] return lower_with(ir_module, transforms) diff --git a/python/hidet/transforms/annotate_header_and_libs.py b/python/hidet/transforms/annotate_header_and_libs.py new file mode 100644 index 000000000..53a974a8e --- /dev/null +++ b/python/hidet/transforms/annotate_header_and_libs.py @@ -0,0 +1,44 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hidet.ir +from hidet.ir.module import IRModule +from hidet.ir.stmt import BlackBoxStmt +from hidet.transforms import Pass + + +def _use_distributed(func) -> bool: + black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) + return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) + + +class AnnotateHeaderAndLibsPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) + if not use_dist: + return ir_module + + from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs + from hidet.cuda.nccl import nccl_available, nccl_library_filename + + if not nccl_available(): + raise RuntimeError("NCCL is not available") + + new_module = ir_module.copy() + new_module.include_dirs.extend(get_nccl_include_dirs()) + new_module.linking_dirs.extend(get_nccl_library_search_dirs()) + new_module.include_headers.append(["nccl.h"]) + new_module.linking_libs.append(":" + nccl_library_filename()) + return new_module + + +def annotate_header_and_libs_pass(): + return AnnotateHeaderAndLibsPass() diff --git a/requirements.txt b/requirements.txt index 563176e72..5af4f488c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,6 @@ packaging # for cuda runtime api and runtime compilation api cuda-python + +# for filestore +filelock diff --git a/tests/distributed/test_file_store.py b/tests/distributed/test_file_store.py new file mode 100644 index 000000000..8ae8ede1f --- /dev/null +++ b/tests/distributed/test_file_store.py @@ -0,0 +1,137 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
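+"""
+Tests for FileStore, the file-based store used to bootstrap distributed process groups.
+
+Each test spawns a helper process so that two FileStore clients share the backing file
+'./tmp'. A typical invocation (assuming the repository layout above) is:
+
+    pytest tests/distributed/test_file_store.py
+"""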
+import pytest +import multiprocessing +from multiprocessing import Process, Queue +import os +import time +from datetime import timedelta +import random + +from hidet.distributed import FileStore + +TMP_PATH = './tmp' + + +def test_filestore_get_hold(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.get('non-existing-key') + + p = Process(target=subproc) + p.start() + store = FileStore(TMP_PATH) + store.set('key', b'value') + time.sleep(1) + assert p.is_alive() + p.terminate() + + +def test_filestore_set_get(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(q): + store = FileStore(TMP_PATH) + store.set_timeout(timedelta(seconds=10)) + b = store.get('key') + q.put(b) + + store = FileStore(TMP_PATH) + store.set('key', b'u98guj89ks') + new_value = b'32894728934798' + store.set('key', new_value) + q = Queue() + p = Process(target=subproc, args=(q,)) + p.start() + ret = q.get() + assert ret == new_value + p.join() + + +def test_filestore_add(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.add('cnt', 1) + store.add('cnt', 2) + + store = FileStore(TMP_PATH) + store.add('cnt', 1) + p = Process(target=subproc) + p.start() + p.join() + ret = store.add('cnt', 2) + assert ret == 6 + + +def test_filestore_del(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.get('key') + + p = Process(target=subproc) + p.start() + store = FileStore(TMP_PATH) + store.set('key', b'value') + store.delete_key('key') + time.sleep(1) + assert p.is_alive() + p.terminate() + + +def test_filestore_wait(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.wait(['key'], timeout=timedelta(seconds=10)) + + p = Process(target=subproc) + p.start() + store = FileStore(TMP_PATH) + time.sleep(1) + assert p.is_alive() + store.set('key', b'test') + p.join() + assert not p.is_alive() + + +def test_filestore_compare_set(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.compare_set("key", b"first", b"second") + + store = FileStore(TMP_PATH) + store.set("key", b"random") + p = Process(target=subproc) + p.start() + p.join() + assert store.get("key") == b"random" + store.set("key", b"first") + store.compare_set("key", b"first", b"second") + p = Process(target=subproc) + p.start() + p.join() + assert store.get("key") == b"second"