Devicebufferless (tinygrad#708)
* runs one metal kernel

* conv2d works

* ops tests are passing

* const folding

* all ops work

* pre commit always passes

* torch works

* working still

* fix graph test

* tests passing

* image almost works

* image conv works

* most images

* fix custom

* fix assignment

* fix compile enet

* clean up comments

* fix realize return value

* include shapetracker in LB repr

* copy should make a copy

* reenable method cache

* fix lna

* dtypes in graph

* forward only for IMAGE=2

* simple realize

* getting close

* fixup new api, it's good except the kernel count

* back to 197 kernels

* tests should pass

* go to a real float

* no type_on_cpu

* fix the docs

* put shapetracker back in its proper place
geohot authored Mar 18, 2023
1 parent 26a3888 commit f5467cf
Showing 37 changed files with 466 additions and 441 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
@@ -79,7 +79,7 @@ jobs:
       run: curl https://media.istockphoto.com/photos/hen-picture-id831791190 | ./recognize | grep hen

   testllvm:
-    name: LLVM Tests
+    name: LLVM Tests (w method cache)
     runs-on: ubuntu-latest

     steps:
@@ -160,7 +160,9 @@ jobs:
     - name: Install Dependencies
       run: pip install -e '.[gpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
     - name: Test GPU IMAGE ops
-      run: GPU=1 IMAGE=2 python3 test/test_ops.py
+      run: |
+        GPU=1 IMAGE=1 python3 test/test_ops.py
+        FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
     - name: Test openpilot model
       run: |
         ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -1,6 +1,12 @@
 repos:
 - repo: local
   hooks:
+  - id: docs
+    name: docs
+    entry: python3 docs/abstractions.py
+    language: system
+    always_run: true
+    pass_filenames: false
   - id: flake8
     name: flake8
     entry: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E203,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
2 changes: 1 addition & 1 deletion .pylintrc
@@ -469,4 +469,4 @@ check-str-concat-over-line-jumps=yes

 # Exceptions that will emit a warning when being caught. Defaults to
 # "Exception"
-overgeneral-exceptions=Exception
+overgeneral-exceptions=builtins.Exception
118 changes: 52 additions & 66 deletions docs/abstractions.py
@@ -77,14 +77,20 @@ class LazyBuffer:
   shape: Tuple[int, ...]
   dtype: DType

+  # a ShapeTracker is used to track things like reshapes and permutes
+  # all MovementOps are zero copy in tinygrad!
+  # the ShapeTracker specifies how the data in the RawBuffer matches to the shape
+  # we'll come back to this later
+  st: ShapeTracker
+
+  # if the LazyBuffer is realized, it has a RawBuffer
+  # we will come back to RawBuffers later
+  realized: Optional[RawBuffer]
+
   # if the lazybuffer is unrealized, it has a LazyOp
   # this LazyOp describes the computation needed to realize this LazyBuffer
   op: Optional[LazyOp]

-  # if the LazyBuffer is realized, it has a DeviceBuffer
-  # we will come back to DeviceBuffers later, first we'll explore the LazyOp
-  realized: Optional[DeviceBuffer]
-
 # LazyOp (in tinygrad/ops.py, code 4/10)
 # in a tree they form an Abstract Syntax Tree for a single GPU kernel
 class LazyOp:
@@ -128,81 +128,60 @@ class LoadOps(Enum): FROMCPU = auto()
 # again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
 print(lazyop.src[0].op)
 assert lazyop.src[0].op.op == LoadOps.FROMCPU
-assert lazyop.src[0].op.arg[0] == [2], "the arg of the FROMCPU LazyOP is the [2.]"
+assert lazyop.src[0].op.arg.fxn == [2], "the arg of the FROMCPU LazyOP is the [2.]"
 assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"

 # now we realize the LazyBuffer
 result.lazydata.realize()
 assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
-# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
-assert 'ClangBuffer' in str(type(result.lazydata.realized))
+assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
 # getting ahead of ourselves, but we can copy the DeviceBuffer toCPU
 assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
 # %%
-# == DeviceBuffer (in tinygrad/ops.py, code 4/10) ==
+# == Union[Interpreted, Compiled] (in tinygrad/ops.py, code 5/10) ==

-# DeviceBuffer is an abstract class to be implemented for each Device backend
-class DeviceBuffer(ABC):
-  # these two are straightforward.
-  # unlike LazyBuffer, there's no need for device, since that's contained in the concrete type
-  shape: Tuple[int, ...]
-  dtype: DType
+# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend

-  # this is the magic method that "fills" a DeviceBuffer and does all the math in tinygrad
-  # NOTE: fromCPU no longer exists here, it's just a one LoadOps AST, LoadOps.FROMCPU
-  def exec_ast(self, ast:LazyOp): raise NotImplementedError("must be implemented")
+# Interpreted backends are very simple (example: CPU and TORCH)
+class Interpreted:
+  # they have a backing RawBuffer
+  buffer: Type[RawBuffer]

-  # however, toCPU still exists. it will raise a RuntimeException if exec_ast has never been called
-  # it copies out the underlying to the CPU, and will do any sync operations
-  def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
+  # and they have a lookup table to functions for the Ops
+  fxn_for_op: Dict[Op, Callable] = {
+    UnaryOps.EXP: lambda x: np.exp(x),
+    BinaryOps.ADD: lambda x,y: x+y}
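As a hedged illustration (not the actual tinygrad walker, which lives in tinygrad/ops.py), realizing an AST on an Interpreted backend amounts to recursing over the LazyOp tree and dispatching each Op through fxn_for_op:

import numpy as np
from tinygrad.ops import LazyOp

# hypothetical recursive interpreter over a LazyOp tree
def interpret(ast:LazyOp, fxn_for_op):
  # leaves are realized LazyBuffers; interior nodes are LazyOps
  srcs = [interpret(s, fxn_for_op) if isinstance(s, LazyOp) else s.realized.toCPU() for s in ast.src]
  return fxn_for_op[ast.op](*srcs)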

-# DeviceBuffers come in two flavors, InterpretedBuffer and CompiledBuffer
-# InterpretedBuffers are a lot simpler than CompiledBuffers
-# they are used to implement the CPU(numpy) and TORCH(torch) backends
-# it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
-import numpy as np
-import torch
-class InterpretedBuffer(DeviceBuffer):
-  # this is where the data actually lives
-  # finally some classes you recognize!
-  _buf: Union[np.ndarray, torch.Tensor]
+# Compiled backends take a little more (example: GPU and LLVM)
+class Compiled:
+  # they also have a backing RawBuffer
+  buffer: Type[RawBuffer]

-  # the compute itself is defined here. these functions are called with _buf
-  # here's a UnaryOp and BinaryOp from CPUBuffer(InterpretedBuffer)
-  fxn_for_op: ClassVar[Dict[Op, Callable]] = {UnaryOps.EXP: lambda x: np.exp(x), BinaryOps.ADD: lambda x,y: x+y}
+  # a code generator, which compiles the AST
+  codegen: Type[ASTKernel]

-  # NOTE: exec_ast should not need to be overridden!
-  # The actual method lives in tinygrad/ops.py
-  # it walks the LazyOp tree and calls fxn_for_op as appropriate
+  # and a runtime, which runs the generated code
+  runtime: Type[Runtime]

-# ********** NOTE: for the CPU and TORCH backends, we are done and you can stop reading here **********
+# Runtime is what actually runs the kernels for a compiled backend
+class Runtime(ABC):
+  # `name` is the name of the function, and `prg` is the code
+  # the constructor compiles the code
+  def __init__(self, name:str, prg:str): pass
+  # call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
+  def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
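To make the Runtime contract concrete, a hedged usage sketch for a C-like backend (ClangRuntime is an illustrative name, not necessarily the real class):

# hypothetical Runtime usage; the constructor compiles, __call__ runs
prg = "void add(float *c, float *a, float *b) { c[0] = a[0] + b[0]; }"
add = ClangRuntime("add", prg)
add(None, None, out_buf, a_buf, b_buf)  # no global/local size on CPU; out_buf is bufs[0]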

 # %%
-# == CompiledBuffer (in tinygrad/ops.py, code 4/10) ==
-
-# however, all the magic of tinygrad will come from CompiledBuffer
-# this is used for the GPU(opencl), CUDA, METAL, CLANG, and LLVM backends
-class CompiledBuffer(DeviceBuffer):
-  # this is where the data actually lives, same as InterpretedBuffer
-  # a RawBuffer is just raw (typed) memory on the Device in question
-  _buf: RawBuffer
-
-  # introducing...ShapeTracker! all MovementOps are zero copy in tinygrad
-  # the ShapeTracker specifies how the data in the RawBuffer matches to the shape
-  # we'll come back to this later
-  st: ShapeTracker
-
-  # NOTE: exec_ast should not need to be overridden!
-  # instead you need three classes, explained below
-  raw_buffer: Type[RawBuffer]
-  runtime: Type[Runtime]
-  codegen: Type[ASTKernel]
+# == RawBuffer (in tinygrad/runtime/lib.py, code 5/10) ==
+import numpy as np

-# for completeness, we include RawBuffer. it's very boring and exactly what you expect
+# RawBuffer is where the data is actually held. it's pretty close to just memory
 class RawBuffer(ABC):
-  # create an empty rawbuffer that holds `size` elements of type `dtype`
-  def __init__(self, size:int, dtype:DType): raise NotImplementedError("must be implemented")
+  # `buf` is an opaque container class
+  def __init__(self, size:int, dtype:DType, buf:Any): raise NotImplementedError("must be implemented")

 # fromCPU is classmethod that creates a RawBuffer, it's a classmethod since some runtimes are 0 copy
 @classmethod
@@ -211,13 +196,14 @@ def fromCPU(cls:RawBuffer, x:np.ndarray) -> RawBuffer: raise NotImplementedError
 # toCPU converts the RawBuffer to a numpy array with shape (size,). many backends are 0 copy here
 def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")

-# Runtime is what actually runs the kernels
-class Runtime(ABC):
-  # `name` is the name of the function, and `prg` is the code
-  # the constructor compiles the code
-  def __init__(self, name:str, prg:str): pass
-  # call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
-  def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
+# RawNumpyBuffer is a RawBuffer example for numpy. It's very simple
+class RawNumpyBuffer(RawBuffer):
+  # NOTE: the "np.ndarray" is stored in the opaque container
+  def __init__(self, buf:np.ndarray):
+    super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
+  @classmethod
+  def fromCPU(cls, x): return cls(x)
+  def toCPU(self): return self._buf
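A hedged round-trip sketch using RawNumpyBuffer (import path assumed):

import numpy as np
from tinygrad.runtime.ops_cpu import RawNumpyBuffer  # assumed location

x = np.array([2.0, 3.0], dtype=np.float32)
rb = RawNumpyBuffer.fromCPU(x)          # zero copy on the numpy backend
np.testing.assert_equal(rb.toCPU(), x)  # and straight back out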

 # %%
 # == Example: 2+3 in raw clang ==
@@ -262,11 +248,11 @@ class ASTKernel:
   def __init__(self, ast:LazyOp): pass
   def codegen(self) -> ASTRunner: pass

-# we return a class that runs code on CompiledBuffers
+# we return a class that runs code on LazyBuffers, which are all expected to be realized
 class ASTRunner:  # (from tinygrad/ops.py)
   def __init__(self, name, prg, global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
   def build(self, runtime:Runtime): pass
-  def exec(self, bufs:List[CompiledBuffer]): pass
+  def exec(self, bufs:List[LazyBuffer]): pass

 # that hides a lot of complexity that will be refactored, but that's the basic idea of code generation
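Putting the Compiled pieces together, a hedged sketch of the full codegen path (names taken from the classes above; treat it as pseudocode, not the exact API):

# hypothetical end-to-end flow for a Compiled backend
kernel = ASTKernel(ast)                # wrap the LazyOp AST
runner = kernel.codegen()              # generate source, get an ASTRunner
runner.build(Device['CLANG'].runtime)  # compile with the backend's Runtime
runner.exec(bufs)                      # run; bufs[0] is the output by convention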
5 changes: 3 additions & 2 deletions examples/compile_efficientnet.py
@@ -43,10 +43,10 @@ def run(x): return model.forward(x).realize()
 # hack to put the inputs back
 assert len(run.input_replace) == 1, f"didn't get one input to replace {run.input_replace}"
 for (j,i),idx in run.input_replace.items():
-  run.jit_cache[j][1][i] = the_input.lazydata.realized.raw()
+  run.jit_cache[j][1][i] = the_input.lazydata.realized

 # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
-special_names = {id(the_input.lazydata.realized.raw()): "input", id(the_output.lazydata.realized.raw()): "outputs"}
+special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}

 functions, statements, bufs, bufs_to_save = compile_net(run, special_names)

@@ -109,4 +109,5 @@ def run(x): return model.forward(x).realize()
   }"""]

 # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/stable_diffusion_by_tinygrad.jpg
+# category : 281 (tabby, tabby cat) with 9.452788
 print('\n'.join(cprog))
4 changes: 2 additions & 2 deletions extra/introspection.py
@@ -3,15 +3,15 @@
 from tinygrad.helpers import prod
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
-from tinygrad.runtime.ops_gpu import GPUBuffer
+from tinygrad.runtime.ops_gpu import CLBuffer
 from tinygrad.ops import GlobalCounters

 def print_objects():
   #gc.collect()
   tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
   tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
   lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
-  gpubuffers = [x for x in gc.get_objects() if isinstance(x, GPUBuffer)]
+  gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)]
   realized_buffers = [x.realized for x in lazybuffers if x.realized]
   gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
4 changes: 2 additions & 2 deletions extra/utils.py
@@ -109,8 +109,8 @@ def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset,
   # this needs real APIs
   if t.device in ["METAL", "CLANG", "LLVM"]:
     del t.lazydata.op
-    t.lazydata.realized = t.lazydata.dbuffer(t.shape, dtype=t.dtype)
-    myfile.readinto(t.lazydata.realized.raw()._buffer())
+    t.lazydata.realized = t.lazydata.dbuffer.buffer(prod(t.shape), dtype=t.dtype)
+    myfile.readinto(t.lazydata.realized._buffer())
   else:
     def _mmap(lna):
       assert myfile._compress_type == 0, "compressed data can't be mmaped"
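The pattern here — allocate an empty device RawBuffer, then readinto its memory — also works on its own; a hedged sketch with illustrative names (shape and weight_path are assumptions):

# hypothetical: stream raw float32 bytes straight into a fresh RawBuffer
from tinygrad.lazy import Device
from tinygrad.helpers import prod, dtypes
buf = Device['CLANG'].buffer(prod(shape), dtype=dtypes.float32)
with open(weight_path, 'rb') as f:
  f.readinto(buf._buffer())  # _buffer() exposes writable memory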
12 changes: 6 additions & 6 deletions openpilot/compile.py
@@ -9,7 +9,7 @@
 if os.getenv("IMAGE", None) is None:
   os.environ['IMAGE'] = '2'

-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, dtypes
 ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
 DEBUGCL = getenv("DEBUGCL", 0)

@@ -38,7 +38,7 @@ def get_random_input_tensors(input_shapes):

 @TinyJit
 def model_exec(run_onnx, using_graph, **inputs):
-  ret = next(iter(run_onnx(inputs).values()))
+  ret = next(iter(run_onnx(inputs).values())).cast(dtypes.float32)
   GlobalCounters.reset()
   GlobalCounters.cache = []  # don't cache pre-realize
   if using_graph: graph.GRAPH = True
@@ -49,7 +49,7 @@ def compile(dat, output_fn):
   Tensor.manual_seed(1337)
   Tensor.no_grad = True
   using_graph = graph.GRAPH
-  graph.GRAPH = False
+  if getenv("GRAPH") < 2: graph.GRAPH = False

   onnx_model = onnx.load(io.BytesIO(dat))
   run_onnx = get_run_onnx(onnx_model)
@@ -63,7 +63,7 @@ def compile(dat, output_fn):
   assert len(model_exec.jit_cache) <= ALLOWED_KERNEL_COUNT or ALLOWED_KERNEL_COUNT == 0, "too many kernels!"

   # pull out inputs and put them in the jit cache
-  input_rawbuffers = {k:inputs[k].lazydata.realized.raw() for k in inputs.keys()}
+  input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
   for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j][1][i] = input_rawbuffers[idx]

   # transform to CL.CACHE
@@ -73,11 +73,11 @@ def compile(dat, output_fn):
     # pass these to thneed
     setattr(prg.clprg, 'op_estimate', prg.op_estimate)
     setattr(prg.clprg, 'prg', prg.prg)
-    cl_cache.append((prg.clprg, [prg.global_size, prg.local_size, *[x._cl for x in args]]))
+    cl_cache.append((prg.clprg, [prg.global_size, prg.local_size, *[x._buf for x in args]]))
     used_ops += prg.op_estimate

   from extra.thneed import Thneed
-  t = Thneed(cl_cache, {k:v._cl for k,v in input_rawbuffers.items()})
+  t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()})

   # save thneed (before run)
   t.save(output_fn)
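Both this script and compile_efficientnet.py lean on the same TinyJit detail: jit_cache holds (prg, args) pairs, and input_replace records which arg slots must be re-pointed at the live inputs. A hedged restatement of that swap:

# re-point traced arg slot (j, i) at the realized RawBuffer of the named input
for (j, i), name in model_exec.input_replace.items():
  model_exec.jit_cache[j][1][i] = inputs[name].lazydata.realized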
8 changes: 8 additions & 0 deletions test/external/external_test_opt.py
@@ -226,6 +226,14 @@ def test_no_reduceop_rerun_alt(self):
     np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5)
     assert cache_len == 1, "reduceop was rerun!"

+  def test_fold_with_contiguous(self):
+    a = Tensor.randn(16, 16, 16)
+    b = Tensor.randn(16, 16)
+    with CLCache():
+      c = (a.sum(2).contiguous() + b).contiguous()
+      c.realize()
+      cache_len = len(GlobalCounters.cache)
+    assert cache_len == 1, "contiguous wasn't folded"

 if __name__ == '__main__':
   unittest.main()
5 changes: 3 additions & 2 deletions test/extra/test_utils.py
@@ -4,12 +4,13 @@
 from extra.utils import fetch, fake_torch_load_zipped
 from PIL import Image

-class TestUtils(unittest.TestCase):
+class TestUtils(unittest.TestCase):
+  @unittest.skip("hangs sometimes")
   def test_fetch_bad_http(self):
     self.assertRaises(AssertionError, fetch, 'http://httpstat.us/500')
     self.assertRaises(AssertionError, fetch, 'http://httpstat.us/404')
     self.assertRaises(AssertionError, fetch, 'http://httpstat.us/400')

   def test_fetch_small(self):
     assert(len(fetch('https://google.com'))>0)
22 changes: 10 additions & 12 deletions test/test_custom_function.py
@@ -8,25 +8,23 @@

 # *** first, we implement the atan2 op at the lowest level ***
 # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers

-from tinygrad.ops import ASTRunner, CompiledBuffer
-from tinygrad.runtime.ops_cpu import CPUBuffer
+from tinygrad.lazy import LazyBuffer, Device
+from tinygrad.ops import ASTRunner

-# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
-def atan2_gpu(a:CompiledBuffer, b:CompiledBuffer) -> CompiledBuffer:
-  from tinygrad.runtime.ops_gpu import GPUBuffer
-  assert type(a) == GPUBuffer and type(b) == GPUBuffer, "gpu function requires GPUBuffers"
+def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
+  assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
   assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
-  ret = GPUBuffer(a.shape)
+  ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
   ASTRunner("atan2", """
 __kernel void atan2(global float *c, global float *a, global float *b) {
   int idx = get_global_id(0);
   c[idx] = atan2(a[idx], b[idx]);
-}""", global_size=[prod(ret.shape)]).build(GPUBuffer.spec.runtime).exec([ret, a.contiguous(), b.contiguous()])
-  return ret
+}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret, a, b])
+  return ret.realized

-def atan2_cpu(a:CPUBuffer, b:CPUBuffer) -> CPUBuffer:
-  return CPUBuffer(np.arctan2(a._buf, b._buf))
+def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
+  return Device[ret.device].buffer(np.arctan2(a.realized._buf, b.realized._buf))

 # *** second, we write the ATan2 mlop ***
 # NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
@@ -40,7 +38,7 @@ class ATan2(Function):
   def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
     assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
     self.a, self.b = a, b
-    ast = LazyOp(LoadOps.CUSTOM, (a, b), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
+    ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
     return LazyBuffer(a.device, a.shape, LoadOps, ast, max(a.dtype, b.dtype))
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b))
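For context, the rest of test_custom_function.py (elided in this view) drives the mlop through the Tensor API; a hedged sketch of that usage:

# hypothetical usage mirroring the elided part of the test
from tinygrad.tensor import Tensor
a = Tensor.randn(4, 4, requires_grad=True)
b = Tensor.randn(4, 4, requires_grad=True)
c = ATan2.apply(a, b)  # Function.apply wires forward/backward into autograd
c.mean().backward()    # gradients flow through the custom op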