Skip to content

Commit

Permalink
hotfix: tasteful ctrl-c in parallel beam
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Dec 5, 2023
1 parent 35b5e95 commit ec594cf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
pass_filenames: false
- id: devicetests
name: select GPU tests
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py
entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
language: system
always_run: true
pass_filenames: false
Expand Down
59 changes: 33 additions & 26 deletions tinygrad/features/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time, multiprocessing, traceback
import itertools, random, math, time, multiprocessing, traceback, signal
from tinygrad.lazy import vars_from_ast
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.ops import MemBuffer
Expand Down Expand Up @@ -101,6 +101,9 @@ def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs,
if early_stop is not None and early_stop < tms[-1]: break
return tms

# workers should ignore ctrl c
def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)

def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
Expand All @@ -112,32 +115,36 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
seen_libs = set()

default_parallel = 1 if Device.DEFAULT == "HIP" else 0
pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None
pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None

try:
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
exiting, st = False, time.perf_counter()
dev = Device[Device.DEFAULT]
assert isinstance(dev, Compiled)
while not exiting:
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
timed_lins: List[Tuple[Linearizer, float]] = []
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
if proc is None: continue
lib, global_size, local_size = proc
if lib in seen_libs: continue
seen_libs.add(lib)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")

# done
opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
if not exiting: beam = opts[:amt]
assert len(beam) > 0, "no BEAM items succeeded?!?"
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
if pool is not None: pool.close() # the pool is closed
except KeyboardInterrupt as e:
if pool is not None: pool.terminate()
raise e

var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
exiting, st = False, time.perf_counter()
dev = Device[Device.DEFAULT]
assert isinstance(dev, Compiled)
while not exiting:
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
timed_lins: List[Tuple[Linearizer, float]] = []
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
if proc is None: continue
lib, global_size, local_size = proc
if lib in seen_libs: continue
seen_libs.add(lib)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")

# done
opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
if not exiting: beam = opts[:amt]
assert len(beam) > 0, "no BEAM items succeeded?!?"
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())

if pool is not None: pool.close() # the pool is closed
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if DEBUG >= 3: print(beam[0][0].applied_opts)
return beam[0][0]
Expand Down

0 comments on commit ec594cf

Please sign in to comment.