forked from gpu-mode/lectures
Commit
Uploaded material for lecture 14 on triton
Showing 2 changed files with 2,553 additions and 0 deletions.
The larger of the two files is not rendered by default; the second file is shown below.
@@ -0,0 +1,56 @@
import os
import triton
import triton.language as tl


def test_pid_conds(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Test if conditions on pids are fulfilled
    E.g.:
        '=0'    checks that pid_0 == 0
        ',>1'   checks that pid_1 > 1
        '>1,=0' checks that pid_0 > 1 and pid_1 == 0
    '''
    pids = pid_0[0], pid_1[0], pid_2[0]
    conds = conds.replace(' ', '').split(',')
    for i, (cond, pid) in enumerate(zip(conds, pids)):
        if cond == '': continue
        op, threshold = cond[0], int(cond[1:])
        if op not in ['<', '>', '>=', '<=', '=', '!=']:
            raise ValueError(f"Rules may only use these ops: '<','>','>=','<=','=','!='. Invalid rule: '{cond}'.")
        op = '==' if op == '=' else op
        if not eval(f'{pid} {op} {threshold}'): return False
    return True


assert test_pid_conds('')
assert test_pid_conds('>0', [1], [1])
assert not test_pid_conds('>0', [0], [1])
assert test_pid_conds('=0,=1', [0], [1], [0])


def breakpoint_if(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Stop kernel, if any condition on pids is fulfilled'''
    from IPython.core.debugger import set_trace

    if test_pid_conds(conds, pid_0, pid_1, pid_2): set_trace()


def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Print txt, if any condition on pids is fulfilled'''
    if test_pid_conds(conds, pid_0, pid_1, pid_2): print(txt)


def check_tensors_gpu_ready(*tensors):
    '''Check that all tensors are contiguous and, unless running in interpreter mode, on the GPU.'''
    for t in tensors:
        assert t.is_contiguous(), "A tensor is not contiguous"
        if not os.environ.get('TRITON_INTERPRET') == '1': assert t.is_cuda, "A tensor is not on cuda"


def cdiv(a, b): return (a + b - 1) // b  # ceiling division, e.g. to compute grid sizes
assert cdiv(10, 2) == 5
assert cdiv(10, 3) == 4


@triton.jit
def get_1d_offest(size, n_prev_chunks):
    # offsets of the chunk that comes after `n_prev_chunks` chunks of `size` elements
    return n_prev_chunks * size + tl.arange(0, size)


@triton.jit
def get_2d_offset(offs_0, offs_1, stride_0, stride_1=1):
    # combine row and column offsets into a 2d grid of flat memory offsets
    return tl.expand_dims(offs_0, 1) * stride_0 + tl.expand_dims(offs_1, 0) * stride_1


@triton.jit
def get_1d_mask(offs, max):
    # mask out offsets that lie beyond `max`
    return offs < max


@triton.jit
def get_2d_mask(offs_0, offs_1, max_0, max_1):
    # mask out offsets that are out of bounds in either dimension
    return (tl.expand_dims(offs_0, 1) < max_0) & (tl.expand_dims(offs_1, 0) < max_1)
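
The helpers above are building blocks for the kernels in the lecture. As a rough sketch of how they fit together (an illustration, not part of the commit: the kernel name copy_2d_k, the block sizes, and the tensor shape are made up here), a simple 2d copy kernel could use them like this:

import torch
# triton, tl, and the helpers defined above are assumed to be in scope

@triton.jit
def copy_2d_k(src_ptr, dst_ptr, n0, n1, stride_0, stride_1, bs0: tl.constexpr, bs1: tl.constexpr):
    # hypothetical example kernel: copies an (n0, n1) matrix tile by tile
    pid_0, pid_1 = tl.program_id(0), tl.program_id(1)
    offs_0 = get_1d_offest(bs0, pid_0)                          # row offsets handled by this program
    offs_1 = get_1d_offest(bs1, pid_1)                          # column offsets handled by this program
    offs = get_2d_offset(offs_0, offs_1, stride_0, stride_1)    # flat memory offsets of the tile
    mask = get_2d_mask(offs_0, offs_1, n0, n1)                  # guard rows/columns past the matrix edge
    vals = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, vals, mask=mask)

x = torch.randn(100, 200, device='cuda')
y = torch.empty_like(x)
check_tensors_gpu_ready(x, y)
grid = (cdiv(100, 32), cdiv(200, 64))   # one program per (32, 64) tile
copy_2d_k[grid](x, y, 100, 200, x.stride(0), x.stride(1), bs0=32, bs1=64)
assert torch.equal(x, y)

The masks are what let the block sizes be chosen independently of the tensor shape: out-of-bounds lanes are skipped by tl.load and tl.store, so the grid can safely be rounded up with cdiv.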