multidevice works (tinygrad#763)
* basic multigpu working

* better multigpu test

* upper

* touchups

* cl sync
geohot authored May 4, 2023
1 parent 4f6d674 commit f28df99
Showing 6 changed files with 78 additions and 8 deletions.
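At the user level, a device string can now carry an index ("gpu:0", "gpu:1"), and tensors can be created on or copied between those devices. A minimal sketch of the intended usage, assuming two OpenCL devices are visible (shapes are illustrative; the full benchmark is the new test below):

from tinygrad.tensor import Tensor

# host tensor, then one copy per OpenCL device
c  = Tensor.ones(1024, device="cpu").realize()
a0 = c.to("gpu:0").realize()  # ":0" canonicalizes to the bare default device, "GPU"
a1 = c.to("gpu:1").realize()  # a second device keeps its index, "GPU:1"

# elementwise work stays on the device that holds its operands
b1 = (a1 + a1).realize()
print(a0.device, a1.device, b1.device)
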
1 change: 1 addition & 0 deletions .gitignore
@@ -18,5 +18,6 @@ vertex.bin
recognize*
.idea
disassemblers/applegpu
+ disassemblers/cuda_ioctl_sniffer
*.prof
datasets/cifar-10-python.tar.gz
66 changes: 66 additions & 0 deletions test/external/external_multi_gpu.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# cd disassemblers/ && git clone --recursive git@github.com:geohot/cuda_ioctl_sniffer.git
# LD_PRELOAD=$PWD/disassemblers/cuda_ioctl_sniffer/out/sniff.so GPU=1 python3 test/external/external_multi_gpu.py
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import colored
from extra.helpers import Timing
from tinygrad.runtime.ops_gpu import CL

# TODO: support multidevice in cuda
device = 'gpu'

if __name__ == "__main__":
sz = 1024*1024*256 # 1 GB
#sz = 1024*64

with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"):
c0 = Tensor.ones(sz, device="cpu").realize()
c1 = (Tensor.ones(sz, device="cpu")/2).realize()

with Timing("CPU -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
a0 = c0.to(f'{device}:0').realize()
CL.synchronize()
with Timing("CPU -> 1: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
b1 = c1.to(f'{device}:1').realize()
CL.synchronize()

# cross copy. this is going through the CPU
with Timing("0 -> 1: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
a1 = a0.to(f'{device}:1').realize()
CL.synchronize()
with Timing("1 -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
b0 = b1.to(f'{device}:0').realize()
CL.synchronize()

# sum
with Timing("0 -> 0 (sum): ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
ab0 = (a0 + b0).realize()
CL.synchronize()
with Timing("1 -> 1 (sum): ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
ab1 = (a1 + b1).realize()
CL.synchronize()

# cross device sum (does this work?)
# is this making a copy first? is that copy through the CPU?
# the slowness comes from the *blocking* clprg call, is this pyopencl?
with Timing(colored("0+1 -> 0 (sum): ", "red"), on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
abx0 = (a0 + b1).realize()
CL.synchronize()

with Timing(colored("1+0 -> 1 (sum): ", "red"), on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"):
abx1 = (b1 + a0).realize()
CL.synchronize()

# devices
print(ab0)
print(ab1)
print(abx0)
print(abx1)

# same
#print("testing")
#np.testing.assert_allclose(ab0.numpy(), ab1.numpy())
#np.testing.assert_allclose(ab0.numpy(), abx0.numpy())
#np.testing.assert_allclose(ab0.numpy(), abx1.numpy())

8 changes: 5 additions & 3 deletions tinygrad/lazy.py
@@ -102,6 +102,7 @@ def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:
if GRAPH >= 3: log_op(self, self.op, phantom=True)

def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"
+ def _device_extra_args(self) -> Dict[str, int]: return {"device": int(self.device.split(":")[1])} if ":" in self.device else {}

def realize(self:LazyBuffer) -> LazyBuffer:
if self.realized is None:
@@ -110,7 +111,7 @@ def realize(self:LazyBuffer) -> LazyBuffer:
if prod(self.op.arg.shape) == 1 and hasattr(Device[self.device].codegen, 'supports_constant_folding'):
self.realized = RawConst(1, dtypes.from_np(self.op.arg.dtype), self.op.arg().flatten()[0])
else:
- self.realized = Device[self.device].buffer.fromCPU(self.op.arg())
+ self.realized = Device[self.device].buffer.fromCPU(self.op.arg(), **self._device_extra_args())
elif self.op.op == LoadOps.CONTIGUOUS:
realized = self.op.src[0].realize().realized
if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
@@ -139,7 +140,7 @@ def realize(self:LazyBuffer) -> LazyBuffer:
self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
self.dtype = dtypes.float32

- self.realized = Device[self.device].exec_ast(self.op, output=self)
+ self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())

assert isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
# HACK: allow hot casting of images
@@ -289,8 +290,9 @@ class _Device:
def __init__(self) -> None:
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, self._default_device())
+ def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
- def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
+ def _get_device(self, x:str) -> Union[Interpreted, Compiled]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def _default_device(self) -> str:
for device in ["METAL", "CUDA", "GPU"]:
try:
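The backend lookup and the device index are now decoupled: Device["GPU:1"] resolves the same cached GPU backend as Device["GPU"], while the index travels separately through _device_extra_args() into buffer allocation and exec_ast. A standalone sketch of that split, mirroring the two one-liners above (plain Python, hypothetical helper name):

def split_device(device: str):
    # the backend name feeds Device[...]; the extra kwargs feed fromCPU/exec_ast
    backend = device.split(":")[0].upper()
    extra = {"device": int(device.split(":")[1])} if ":" in device else {}
    return backend, extra

print(split_device("gpu:1"))  # ('GPU', {'device': 1})
print(split_device("GPU"))    # ('GPU', {})
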
4 changes: 2 additions & 2 deletions tinygrad/ops.py
@@ -128,7 +128,7 @@ def __init__(self, buffer: Type[RawBuffer], codegen, runtime, synchronize=lambda
self.buffer, self.codegen, self.runtime, self.synchronize = buffer, codegen, runtime, synchronize
self.method_cache: Dict[str, ASTRunner] = {}

- def exec_ast(self, ast:LazyOp, output):
+ def exec_ast(self, ast:LazyOp, output, **kwargs):
# all movementops do nothing in a Compiled buffer!
if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized

@@ -145,7 +145,7 @@ def exec_ast(self, ast:LazyOp, output):

# we don't have an output buffer, we have to create it
if output.realized is None:
- output.realized = self.buffer(prod(output.shape), output.dtype)
+ output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)

# compilation time
k = self.codegen(ast, output)
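Since exec_ast now accepts **kwargs, whatever _device_extra_args() produced on the LazyBuffer reaches the output buffer's constructor unchanged. A toy version of that plumbing (FakeBuffer is a stand-in, not the real RawBuffer hierarchy):

from math import prod

class FakeBuffer:
    # stand-in for a RawBuffer subclass whose constructor accepts a device index
    def __init__(self, size, dtype, device=0):
        self.size, self.dtype, self.device = size, dtype, device

def exec_ast(output_shape, dtype, **kwargs):
    # mirrors the change above: unrecognized keyword arguments pass straight
    # through to the output buffer's constructor
    return FakeBuffer(prod(output_shape), dtype, **kwargs)

print(exec_ast((1024, 1024), "float32", device=1).device)  # 1
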
4 changes: 2 additions & 2 deletions tinygrad/runtime/lib.py
@@ -23,8 +23,8 @@ class RawBufferCopyIn(RawBuffer):
def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

@classmethod
- def fromCPU(cls, x:np.ndarray):
-   ret = cls(prod(x.shape), dtypes.from_np(x.dtype))
+ def fromCPU(cls, x:np.ndarray, **kwargs):
+   ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
ret._copyin(x)
return ret

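On the backend side, fromCPU forwards its keyword arguments to cls(...), so a RawBufferCopyIn subclass only needs a constructor that accepts the extra device keyword. A self-contained sketch of the pattern, with numpy standing in for device memory (toy class, not the real OpenCL buffer):

import numpy as np

class ToyBufferCopyIn:
    def __init__(self, size, dtype, device=0):
        self.size, self.dtype, self.device = size, dtype, device
        self._buf = np.empty(size, dtype)  # pretend this is an allocation on the given device
    def _copyin(self, x: np.ndarray) -> None:
        np.copyto(self._buf, x.reshape(-1).astype(self.dtype))
    @classmethod
    def fromCPU(cls, x: np.ndarray, **kwargs):
        ret = cls(x.size, x.dtype, **kwargs)  # same shape as the change above
        ret._copyin(x)
        return ret

buf = ToyBufferCopyIn.fromCPU(np.ones((4, 4), np.float32), device=1)
print(buf.device, buf._buf.shape)  # 1 (16,)
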
3 changes: 2 additions & 1 deletion tinygrad/tensor.py
@@ -34,6 +34,7 @@ class Tensor:
default_type: ClassVar[DType] = dtypes.float32

def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+ device = device.upper().replace(":0", "") # canonicalize device
if isinstance(data, list):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
elif isinstance(data, LazyBuffer) and data.device != device:
@@ -68,7 +69,7 @@ def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], dev
self._ctx: Optional[Function] = None

def __repr__(self):
return f"<Tensor {self.lazydata if self.lazydata.realized is None else self.lazydata.realized!r} with grad {(self.grad.lazydata if self.grad else None)!r}>"
return f"<Tensor {self.lazydata if self.lazydata.realized is None else self.lazydata.realized!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"

# Python has a non moving GC, so this should be okay
def __hash__(self): return id(self)
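Canonicalization makes ":0" equivalent to the bare device name, so "gpu", "GPU", and "gpu:0" all end up as the same .device string and mix freely; only non-zero indices are kept. A tiny sketch of the rule applied above (standalone helper, hypothetical name):

def canonicalize(device: str) -> str:
    # same rule as Tensor.__init__ above: uppercase, and ":0" is the default device
    return device.upper().replace(":0", "")

assert canonicalize("gpu:0") == "GPU"
assert canonicalize("gpu:1") == "GPU:1"
assert canonicalize("cpu") == "CPU"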
