diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2cc71e29050d1..a339071b04155 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,8 +26,8 @@ repos:
         always_run: true
         pass_filenames: false
       - id: tests
-        name: subset of (CPU) tests
-        entry: env PYTHONPATH="." CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
+        name: subset of TORCH tests
+        entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
         language: system
         always_run: true
         pass_filenames: false
diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 8f49584fa972d..7cf22b1e0f63b 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -3,7 +3,7 @@
 from collections import Counter, defaultdict
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
 from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
+from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
 from tinygrad.graph import print_tree
 from tinygrad.helpers import ImageDType, prod, getenv
 from tinygrad.ops import Device, Compiled, Interpreted
@@ -26,10 +26,7 @@ def fuzz_linearizer(lin: Linearizer):
   print(lin.colored_shape())
   rawbufs = bufs_from_lin(lin)

-  # NOTE: copied from beam_search
-  def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
   seen_uops = {}
-  ground_truth = None

   while 1:
     if len(seen_uops) >= 20: break # enough for this kernel
diff --git a/test/test_ops.py b/test/test_ops.py
index 2d003e4e9c68c..4c1ecaf417776 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -140,6 +140,12 @@ def test_arange_big(self):
   def test_sum_collapse(self):
     helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True)

+  def test_sum_collapse_neg(self):
+    helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), forward_only=True)
+
+  def test_max_dont_collapse(self):
+    helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True)
+
   def test_where(self):
     helper_test_op(
       [(100,)],
diff --git a/test/test_uops.py b/test/test_uops.py
index a18523b45f4d9..2d0554537a9ab 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -13,7 +13,7 @@ def _uops_to_prg(uops):
                    runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)

 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
-  uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
+  uops.append(UOp(uop, dtype, tuple(vin), arg))
   return uops[-1]

 def _test_single_value(vals, op, dtype):
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 4de88e48f66b9..da91b05626118 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -5,7 +5,7 @@
 from enum import Enum, auto
 from dataclasses import dataclass

-from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv
+from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same
 from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
 from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -21,19 +21,13 @@ class UOps(Enum):
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
   ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702

-@dataclass
+@dataclass(eq=False)
 class UOp:
   uop: UOps
   dtype: Optional[DType]
   vin: Tuple[UOp, ...]
   arg: Any
-  def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
-  #def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"
-
-  # UOps are unique
-  num: int
-  def __hash__(self): return self.num
-  def __eq__(self, x): return self.num == x.num
+  def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"

 def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
   local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
@@ -52,7 +46,7 @@ def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
     return self.uop(UOps.ALU, dtype, (a, render_b), op)

   # NOTE: the consts have to be be cached for deduping of downstream uops to work
-  def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
+  def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)

   render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
                       MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
@@ -215,7 +209,6 @@ def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
     # set global/local size
     self.global_size: Optional[List[int]] = None
     self.local_size: Optional[List[int]] = None
-    global_loop_ctx: Tuple[UOp, ...] = tuple()
     if self.dont_use_locals:
       self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -226,7 +219,7 @@ def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
     else:
-      global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
+      render_loop(loop_global_idxs+loop_local_idxs)

     # parse AST
     loaded_buffers = {}
@@ -302,7 +295,7 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
                      self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
              ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
              for z in range(cast(DType, ret.dtype).sz):
-                acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
+                acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
              i += wmma_sz[2]
         else:
           if locals_to_store:
@@ -314,7 +307,7 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
         loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})

         # run early AST (with reduce)
-        self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
+        self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)

       # end the reduce loop
       self.load_cache.clear()
@@ -365,11 +358,54 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
     loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})

     # run late AST
-    val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
+    val = self.ast_parse(self.ast, acc, None, loaded_buffers)

     # store
     self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

+    # graph helper functions
+    def get_recursive_parents(x:List[UOp]) -> List[UOp]:
+      ret: Set[UOp] = set()
+      this_round: Set[UOp] = set(x)
+      while len(this_round):
+        ret = ret.union(this_round)
+        next_round: Set[UOp] = set()
+        for r in this_round: next_round = next_round.union(set(r.vin))
+        this_round = next_round
+      return list(ret)
+
+    def get_recursive_children(x:UOp) -> List[UOp]:
+      deps = set([x])
+      ssize = 0
+      while ssize != len(deps):
+        ssize = len(deps)
+        for u in self.uops:
+          if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
+            deps.add(u)
+      return sorted(list(deps), key=self.uops.index) # get the last one
+
+    def replace_op(old:UOp, new:UOp):
+      for u in self.uops:
+        u.vin = tuple(new if x is old else x for x in u.vin)
+      self.uops.remove(old)
+
+    # uops optimization
+    changed_something = True
+    while changed_something:
+      changed_something = False
+      for u in self.uops:
+        if u.uop == UOps.PHI and len(u.vin) == 3:
+          # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
+          # TODO: ADD becomes a MUL, MAX can just become nothing
+          if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
+            if DEBUG >= 4: print(f"removing PHI node {u}")
+            del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
+            # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
+            loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
+            if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
+            replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
+            changed_something = True
+
     # (recursively) remove childless uops
     # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
     UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
@@ -382,32 +418,21 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
       if len(nu) == len(self.uops): break
       if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
       self.uops = nu
+      del nu

-    def get_recursive_deps(x:UOp) -> List[UOp]:
-      deps = set([x])
-      ssize = 0
-      while ssize != len(deps):
-        ssize = len(deps)
-        for u in self.uops:
-          if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
-            deps.add(u)
-      return sorted(list(deps), key=lambda x: x.num)
-
-    # add END of loops after the last thing that (recursively) depends on them
-    # and END any if statements
+    # add UOps.END
     for u in self.uops:
       if u.uop == UOps.LOOP:
-        last_phi = self.uops.index(get_recursive_deps(u)[-1])
-        at_end = self.uops[last_phi+1:]
-        self.uops = self.uops[:last_phi+1]
-        self.uop(UOps.END, None, (u,), cachable=False)
-        self.uops += at_end
+        # add END of loops after the last thing that (recursively) depends on them
+        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
       elif u.uop == UOps.IF:
+        # END any if statements at the end of the uops
         self.uop(UOps.END, None, (u,), cachable=False)

     # maybe graph the uops
     if DEBUG >= 5:
-      for u in self.uops: print(u)
+      for u in self.uops:
+        print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
     if getenv("GRAPHUOPS"):
       from tinygrad.graph import graph_uops
       graph_uops(self.uops)
@@ -419,27 +444,32 @@ def get_recursive_deps(x:UOp) -> List[UOp]:
     self.applied_opts_cache = self.applied_opts[:]
     return self

-  def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
+  def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
     key = (uop, dtype, vin, arg)
-    if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
-    if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
-    if uop == UOps.ALU:
-      # rewrites. NOTE: the rewritten NEG op is still around...
-      if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
-      # constant folding
-      if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
-      # zero folding
-      for x in [0,1]:
-        if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
-        if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
-        if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
-      if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
-      if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
+    if simplify:
+      if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
+      if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
+      if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before)
+      if uop == UOps.ALU:
+        # rewrites. NOTE: the rewritten NEG op is still around...
+        if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
+        # constant folding
+        if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
+        # zero folding
+        for x in [0,1]:
+          if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
+          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
+          if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
+        if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
+        if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
     if cachable and key in self.saved_exprs: return self.saved_exprs[key]
-    self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
-    #if DEBUG >= 5: print(self.uops[-1])
-    if cachable: self.saved_exprs[key] = self.uops[-1]
-    return self.uops[-1]
+    ret = UOp(uop, dtype, vin, arg)
+    if insert_before is not None:
+      self.uops.insert(insert_before, ret)
+    else:
+      self.uops.append(ret)
+    if cachable: self.saved_exprs[key] = ret
+    return ret

   def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
     if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index c48a3de482fe1..9d402cf8247d9 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -3,7 +3,7 @@
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
 from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
-from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.codegen.linearizer import Linearizer, UOp
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
 from tinygrad.tensor import Tensor
@@ -97,6 +97,8 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
         pass
   return acted_lins

+def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
+
 def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
   if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
@@ -108,7 +110,6 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]

   # NOTE: real uops use a weird compare method that's only valid inside a linearizer
-  def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
   seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}

   while 1:
diff --git a/tinygrad/graph.py b/tinygrad/graph.py
index 0ca9d2ec37635..4b4bbc4ac947b 100644
--- a/tinygrad/graph.py
+++ b/tinygrad/graph.py
@@ -114,8 +114,8 @@ def graph_uops(uops):
   G = nx.DiGraph()
   for u in uops:
     if u.uop == UOps.END: continue
-    G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
-    for v in u.vin: G.add_edge(v.num, u.num)
+    G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
+    for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
   GRAPHPATH = "/tmp/uops"
   nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
   os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
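
Note for reviewers: the heart of the linearizer change above is the PHI-folding pass. When a PHI accumulator's inputs do not reach a LOOP uop and the accumulation op is ADD, the loop is collapsed to a single multiply by the trip count (loop_len = end - start); that is what lets test_sum_collapse_neg pass, while MAX reductions are deliberately left alone (test_max_dont_collapse). Below is a minimal standalone Python sketch of that idea, illustrative only: the ToyLoop/ToyAccumulate/collapse names are hypothetical and not tinygrad APIs.

# sketch only: loop-invariant ADD accumulation folds to value * trip_count
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyLoop:
  start: int
  end: int

@dataclass
class ToyAccumulate:
  # models "acc = acc + value" executed once per loop iteration
  value: float
  loop: ToyLoop
  value_depends_on_loop_var: bool

def collapse(node: ToyAccumulate) -> Optional[float]:
  # if the accumulated value never reads the loop variable, the whole loop of ADDs
  # becomes a single multiply by the trip count, mirroring the PHI -> MUL rewrite
  # in linearizer.py (where loop_len is built as an ALU SUB of the loop bounds)
  if node.value_depends_on_loop_var: return None  # can't fold, keep the loop
  return node.value * (node.loop.end - node.loop.start)

if __name__ == "__main__":
  # like test_sum_collapse_neg: a row of three -1s sums to -3 with no loop executed
  assert collapse(ToyAccumulate(-1.0, ToyLoop(0, 3), False)) == -3.0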