uops loop removal (tinygrad#2262)
* remove the loop

* cleanups

* tests failing still

* global_loop_ctx wasn't needed

* replace_op is cleaner

* minor opt

* cast opt was wrong

* uop_num

* uop num was dumb

* tuplize_uops

* torch tests

* fix test_uops
geohot authored Nov 10, 2023
1 parent a753c8e commit 85d26dd
Showing 7 changed files with 97 additions and 63 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -26,8 +26,8 @@ repos:
always_run: true
pass_filenames: false
- id: tests
name: subset of (CPU) tests
entry: env PYTHONPATH="." CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
name: subset of TORCH tests
entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
language: system
always_run: true
pass_filenames: false
5 changes: 1 addition & 4 deletions test/external/fuzz_linearizer.py
@@ -3,7 +3,7 @@
from collections import Counter, defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
from tinygrad.graph import print_tree
from tinygrad.helpers import ImageDType, prod, getenv
from tinygrad.ops import Device, Compiled, Interpreted
@@ -26,10 +26,7 @@ def fuzz_linearizer(lin: Linearizer):
print(lin.colored_shape())
rawbufs = bufs_from_lin(lin)

# NOTE: copied from beam_search
def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
seen_uops = {}

ground_truth = None
while 1:
if len(seen_uops) >= 20: break # enough for this kernel
6 changes: 6 additions & 0 deletions test/test_ops.py
@@ -140,6 +140,12 @@ def test_arange_big(self):
def test_sum_collapse(self):
helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True)

def test_sum_collapse_neg(self):
helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), forward_only=True)

def test_max_dont_collapse(self):
helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True)

def test_where(self):
helper_test_op(
[(100,)],
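The two new tests pin down both sides of the loop-removal change in tinygrad/codegen/linearizer.py below: an ADD reduce over a loop-invariant value may be collapsed, while MAX must keep its loop. A minimal standalone sketch of the same behavior (runs on the default backend; expected values come from the math, not from any claim about this commit's output):

import numpy as np
from tinygrad.tensor import Tensor

# sum of a loop-invariant value collapses to a multiply; max must not collapse
row_sums = (-Tensor.ones(3,3)).sum(axis=1).numpy()           # each row: three -1s -> -3.0
np.testing.assert_allclose(row_sums, np.full(3, -3.0))
assert Tensor.ones(256,256).max(1).numpy().shape == (256,)    # still a correct reduce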
2 changes: 1 addition & 1 deletion test/test_uops.py
@@ -13,7 +13,7 @@ def _uops_to_prg(uops):
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)

def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
uops.append(UOp(uop, dtype, tuple(vin), arg))
return uops[-1]

def _test_single_value(vals, op, dtype):
134 changes: 82 additions & 52 deletions tinygrad/codegen/linearizer.py
@@ -5,7 +5,7 @@
from enum import Enum, auto
from dataclasses import dataclass

from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
@@ -21,19 +21,13 @@ class UOps(Enum):
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702

@dataclass
@dataclass(eq=False)
class UOp:
uop: UOps
dtype: Optional[DType]
vin: Tuple[UOp, ...]
arg: Any
def __repr__(self): return f"{self.num:4d} {str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.num for x in self.vin]):32s} {self.arg}"
#def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str(self.vin):32s} {self.arg}"

# UOps are unique
num: int
def __hash__(self): return self.num
def __eq__(self, x): return self.num == x.num
def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"

def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)]
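The dataclass change above is the core bookkeeping cleanup of this commit: with eq=False, UOps hash and compare by object identity, so the explicit num field and the hand-written __hash__/__eq__ go away, and anything that wants a stable ordinal (the repr, graphing, tuplize_uops) asks the uops list for an index instead. A quick sketch of what that means in practice, assuming these imports resolve at this commit:

from tinygrad.codegen.linearizer import UOp, UOps
from tinygrad.helpers import dtypes

a = UOp(UOps.CONST, dtypes.int32, (), 1)
b = UOp(UOps.CONST, dtypes.int32, (), 1)
assert a != b and len({a, b}) == 2   # eq=False: identity-based comparison and hashing
assert [a, b].index(b) == 1          # ordinals now come from list position, not .num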
@@ -52,7 +46,7 @@ def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
return self.uop(UOps.ALU, dtype, (a, render_b), op)

# NOTE: the consts have to be be cached for deduping of downstream uops to work
def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b)
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)

render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
@@ -215,7 +209,6 @@ def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
# set global/local size
self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None
global_loop_ctx: Tuple[UOp, ...] = tuple()
if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
@@ -226,7 +219,7 @@ def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)})
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
global_loop_ctx = render_loop(loop_global_idxs+loop_local_idxs)
render_loop(loop_global_idxs+loop_local_idxs)

# parse AST
loaded_buffers = {}
@@ -302,7 +295,7 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
self.uop(UOps.CAST, dtypes._float8, tuple(op3)))
ret = self.uop(UOps.WMMA, dtypes._float2 if wmma_sz[2] == 2 else dtypes._float8, ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
for z in range(cast(DType, ret.dtype).sz):
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + global_loop_ctx + loop_ctx)
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
i += wmma_sz[2]
else:
if locals_to_store:
@@ -314,7 +307,7 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})

# run early AST (with reduce)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=global_loop_ctx + loop_ctx)
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)

# end the reduce loop
self.load_cache.clear()
@@ -365,11 +358,54 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})

# run late AST
val = self.ast_parse(self.ast, acc, None, loaded_buffers, loop_ctx=global_loop_ctx)
val = self.ast_parse(self.ast, acc, None, loaded_buffers)

# store
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

# graph helper functions
def get_recursive_parents(x:List[UOp]) -> List[UOp]:
ret: Set[UOp] = set()
this_round: Set[UOp] = set(x)
while len(this_round):
ret = ret.union(this_round)
next_round: Set[UOp] = set()
for r in this_round: next_round = next_round.union(set(r.vin))
this_round = next_round
return list(ret)

def get_recursive_children(x:UOp) -> List[UOp]:
deps = set([x])
ssize = 0
while ssize != len(deps):
ssize = len(deps)
for u in self.uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return sorted(list(deps), key=self.uops.index) # get the last one

def replace_op(old:UOp, new:UOp):
for u in self.uops:
u.vin = tuple(new if x is old else x for x in u.vin)
self.uops.remove(old)

# uops optimization
changed_something = True
while changed_something:
changed_something = False
for u in self.uops:
if u.uop == UOps.PHI and len(u.vin) == 3:
# if the parents of the PHI node don't have the LOOP in their parents, it can be folded
# TODO: ADD becomes a MUL, MAX can just become nothing
if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
if DEBUG >= 4: print(f"removing PHI node {u}")
del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
# NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u))
if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u))
replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u)))
changed_something = True

# (recursively) remove childless uops
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
@@ -382,32 +418,21 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
del nu

def get_recursive_deps(x:UOp) -> List[UOp]:
deps = set([x])
ssize = 0
while ssize != len(deps):
ssize = len(deps)
for u in self.uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return sorted(list(deps), key=lambda x: x.num)

# add END of loops after the last thing that (recursively) depends on them
# and END any if statements
# add UOps.END
for u in self.uops:
if u.uop == UOps.LOOP:
last_phi = self.uops.index(get_recursive_deps(u)[-1])
at_end = self.uops[last_phi+1:]
self.uops = self.uops[:last_phi+1]
self.uop(UOps.END, None, (u,), cachable=False)
self.uops += at_end
# add END of loops after the last thing that (recursively) depends on them
self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
elif u.uop == UOps.IF:
# END any if statements at the end of the uops
self.uop(UOps.END, None, (u,), cachable=False)

# maybe graph the uops
if DEBUG >= 5:
for u in self.uops: print(u)
for u in self.uops:
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
if getenv("GRAPHUOPS"):
from tinygrad.graph import graph_uops
graph_uops(self.uops)
@@ -419,27 +444,32 @@ def get_recursive_deps(x:UOp) -> List[UOp]:
self.applied_opts_cache = self.applied_opts[:]
return self

def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True) -> UOp:
def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp:
key = (uop, dtype, vin, arg)
if uop == UOps.PHI and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self phi is noop
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype)
if uop == UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype)
# zero folding
for x in [0,1]:
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if simplify:
if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before)
if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before)
if uop == UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before)
# zero folding
for x in [0,1]:
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
#if DEBUG >= 5: print(self.uops[-1])
if cachable: self.saved_exprs[key] = self.uops[-1]
return self.uops[-1]
ret = UOp(uop, dtype, vin, arg)
if insert_before is not None:
self.uops.insert(insert_before, ret)
else:
self.uops.append(ret)
if cachable: self.saved_exprs[key] = ret
return ret

def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
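Reading the linearizer.py changes as a whole: the new pass finds a PHI node whose first two inputs have no LOOP in their ancestry and whose update is an ADD, and replaces it with a multiply by the loop's trip count (a SUB of the LOOP bounds, CAST if the dtypes differ); the generic childless-uop sweep then deletes the now-unused loop. A plain-Python analogy of that rewrite, with illustrative bounds and value (this is not tinygrad code):

v = -1.0                          # a loop-invariant value, cf. test_sum_collapse_neg
acc = 0.0                         # DEFINE_ACC starts at the ADD identity
for i in range(0, 256):           # LOOP(0, 256)
    acc = acc + v                 # ALU ADD feeding a PHI back into the accumulator

loop_len = 256 - 0                # SUB of the LOOP bounds
collapsed = (0.0 + v) * loop_len  # the PHI is replaced by vin[1] * loop_len
assert acc == collapsed           # both are -256.0; the dead LOOP is pruned afterwards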
5 changes: 3 additions & 2 deletions tinygrad/features/search.py
@@ -3,7 +3,7 @@
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
from tinygrad.tensor import Tensor
@@ -97,6 +97,8 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
pass
return acted_lins

def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])

def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
@@ -108,7 +110,6 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]

# NOTE: real uops use a weird compare method that's only valid inside a linearizer
def tuplize_uops(uops): return tuple([(x.uop, x.dtype, tuple(x.num for x in x.vin), x.arg) for x in uops])
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}

while 1:
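The tuplize_uops helper promoted above (out of beam_search, and now also imported by test/external/fuzz_linearizer.py) builds a hashable fingerprint of a linearization; since UOps no longer carry num, each input is encoded by its position in the uops list. A small sketch of the key it produces, assuming these imports resolve at this commit:

from tinygrad.codegen.linearizer import UOp, UOps
from tinygrad.features.search import tuplize_uops
from tinygrad.helpers import dtypes
from tinygrad.ops import UnaryOps

c = UOp(UOps.CONST, dtypes.int32, (), 1)
n = UOp(UOps.ALU, dtypes.int32, (c,), UnaryOps.NEG)
key = tuplize_uops([c, n])
# each entry is (uop, dtype, input positions, arg); n's only input is uops[0]
assert key[1][2] == (0,) and isinstance(hash(key), int)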
4 changes: 2 additions & 2 deletions tinygrad/graph.py
@@ -114,8 +114,8 @@ def graph_uops(uops):
G = nx.DiGraph()
for u in uops:
if u.uop == UOps.END: continue
G.add_node(u.num, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
for v in u.vin: G.add_edge(v.num, u.num)
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff"))
for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
GRAPHPATH = "/tmp/uops"
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg')
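With num gone, graph_uops now labels and connects nodes by their position in the uops list. A hedged, self-contained way to exercise it (mirrors the hand-built uop lists in test_uops.py; assumes networkx, pydot and graphviz's dot are installed, and writes /tmp/uops.dot and /tmp/uops.svg as shown above):

from tinygrad.codegen.linearizer import UOp, UOps
from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps
from tinygrad.graph import graph_uops

c0 = UOp(UOps.CONST, dtypes.float32, (), 2.0)
c1 = UOp(UOps.CONST, dtypes.float32, (), 3.0)
alu = UOp(UOps.ALU, dtypes.float32, (c0, c1), BinaryOps.ADD)
graph_uops([c0, c1, alu])   # nodes 0, 1, 2 by list index; edges 0->2 and 1->2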
