Skip to content

Commit

Permalink
Uploaded material for lecture 14 on triton
Browse files Browse the repository at this point in the history
  • Loading branch information
UmerHA committed Apr 13, 2024
1 parent f6884e7 commit 5696e30
Show file tree
Hide file tree
Showing 2 changed files with 2,553 additions and 0 deletions.
2,497 changes: 2,497 additions & 0 deletions lecture 14/A_Practitioners_Guide_to_Triton.ipynb

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions lecture 14/triton_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import triton
import triton.language as tl

def test_pid_conds(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Test if all conditions on the pids are fulfilled.

    `conds` is a comma-separated string holding one rule per program-id axis,
    in the order pid_0, pid_1, pid_2. A rule is an operator followed by an
    integer; an empty rule puts no constraint on that axis. Spaces are ignored.
    Supported operators: '<', '>', '>=', '<=', '=', '!='.

    E.g.:
        '=0'     checks that pid_0 == 0
        ',>1'    checks that pid_1 > 1
        '>1,=0'  checks that pid_0 > 1 and pid_1 == 0
        '>=2'    checks that pid_0 >= 2

    Each pid_* argument is indexed with [0] to get the value that is compared.
    The list defaults are only read, never mutated.

    Returns True if every non-empty rule holds, otherwise False.
    Raises ValueError for an unknown operator or a non-integer threshold.
    '''
    import operator

    ops = {'<': operator.lt, '>': operator.gt, '>=': operator.ge,
           '<=': operator.le, '=': operator.eq, '!=': operator.ne}
    pids = pid_0[0], pid_1[0], pid_2[0]
    # zip stops at the shortest input, so rules beyond the third are ignored.
    for cond, pid in zip(conds.replace(' ', '').split(','), pids):
        if cond == '': continue
        # Try two-character operators first so '>=' is not read as '>' followed by '=...'.
        op = cond[:2] if cond[:2] in ops else cond[:1]
        if op not in ops:
            raise ValueError(f"Rules may only use these ops: '<','>','>=','<=','=', '!='. Invalid rule: '{cond}'.")
        try:
            threshold = int(cond[len(op):])
        except ValueError:
            raise ValueError(f"Rule threshold must be an integer. Invalid rule: '{cond}'.") from None
        if not ops[op](pid, threshold): return False
    return True

# The name starts with `test_`, so tell pytest not to collect this helper as a test case.
test_pid_conds.__test__ = False

assert test_pid_conds('')
assert test_pid_conds('>0', [1], [1])
assert not test_pid_conds('>0', [0], [1])
assert test_pid_conds('=0,=1', [0], [1], [0])
assert test_pid_conds('>=1,!=0', [1], [2])

def breakpoint_if(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Stop kernel, if all conditions on the pids are fulfilled.

    `conds` and the pid_* arguments are passed unchanged to `test_pid_conds`.
    When it returns True, the IPython debugger is started via `set_trace()`;
    otherwise the function returns without doing anything.
    '''
    if not test_pid_conds(conds, pid_0, pid_1, pid_2): return
    # Imported lazily so IPython is only required when a breakpoint actually fires.
    from IPython.core.debugger import set_trace
    set_trace()

def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Print txt, if all conditions on the pids are fulfilled.

    `conds` and the pid_* arguments are passed unchanged to `test_pid_conds`;
    `txt` is printed only when that check returns a truthy value.
    '''
    should_print = test_pid_conds(conds, pid_0, pid_1, pid_2)
    if should_print:
        print(txt)

def check_tensors_gpu_ready(*tensors):
    '''Assert that every given tensor is contiguous and, unless interpreting, on cuda.

    The cuda check is skipped when the environment variable TRITON_INTERPRET
    is set to '1'. Raises AssertionError for the first tensor failing a check.
    '''
    interpreting = os.environ.get('TRITON_INTERPRET') == '1'
    for t in tensors:
        # is_contiguous must be called: the bare bound method is always truthy,
        # which made this check a no-op. (Assumes torch-style tensors.)
        assert t.is_contiguous(), "A tensor is not contiguous"
        if not interpreting: assert t.is_cuda, "A tensor is not on cuda"

def cdiv(a, b):
    '''Divide a by b with floor division after adding b - 1 to a.

    For the positive integers checked below this rounds the quotient up.
    '''
    padded = a + b - 1
    return padded // b

assert cdiv(10, 2) == 5
assert cdiv(10, 3) == 4

@triton.jit
def get_1d_offest(size, n_prev_chunks):
    '''Return `n_prev_chunks * size + tl.arange(0, size)`.

    I.e. the `size` consecutive offsets that start right after
    `n_prev_chunks` earlier chunks of the same size.
    '''
    chunk_start = n_prev_chunks * size
    return chunk_start + tl.arange(0, size)

@triton.jit
def get_2d_offset(offs_0, offs_1, stride_0, stride_1=1):
    '''Combine two offset vectors into one grid of offsets.

    `offs_0` is expanded along axis 1 and scaled by `stride_0`; `offs_1` is
    expanded along axis 0 and scaled by `stride_1`; the two are then added.
    Presumably both inputs are 1-D, so the sum broadcasts to 2-D — confirm with callers.
    '''
    along_0 = tl.expand_dims(offs_0, 1) * stride_0
    along_1 = tl.expand_dims(offs_1, 0) * stride_1
    return along_0 + along_1

@triton.jit
def get_1d_mask(offs, max):
    '''Return the elementwise comparison `offs < max`.'''
    in_bounds = offs < max
    return in_bounds

@triton.jit
def get_2d_mask(offs_0, offs_1, max_0, max_1):
    '''Return a mask that is true where both offsets are below their limits.

    `offs_0` is expanded along axis 1 and compared with `max_0`; `offs_1` is
    expanded along axis 0 and compared with `max_1`; the results are and-ed.
    '''
    below_max_0 = tl.expand_dims(offs_0, 1) < max_0
    below_max_1 = tl.expand_dims(offs_1, 0) < max_1
    return below_max_0 & below_max_1

0 comments on commit 5696e30

Please sign in to comment.