Skip to content

Commit

Permalink
Fixed subtle bug in triton_util.py
Browse files Browse the repository at this point in the history
UmerHA authored May 12, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent dc0af4f commit 1ed758a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lecture_014/triton_util.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):

def check_tensors_gpu_ready(*tensors):
    """Assert that every given tensor is contiguous and — unless running under
    the Triton interpreter — resides on a CUDA device.

    Args:
        *tensors: torch tensors to validate before launching a Triton kernel.

    Raises:
        AssertionError: if any tensor is non-contiguous, or is not on CUDA
            while TRITON_INTERPRET != '1'.
    """
    for t in tensors:
        # NOTE: is_contiguous is a method — it must be CALLED. The bare
        # attribute `t.is_contiguous` is a bound method object, which is
        # always truthy, so the assert would silently never fire.
        assert t.is_contiguous(), "A tensor is not contiguous"
        # Under the Triton interpreter (CPU emulation) tensors need not be on CUDA.
        if not os.environ.get('TRITON_INTERPRET') == '1': assert t.is_cuda, "A tensor is not on cuda"

def cdiv(a,b): return (a + b - 1) // b
@@ -53,4 +53,4 @@ def get_2d_offset(offs_0, offs_1, stride_0, stride_1=1): return tl.expand_dims(
def get_1d_mask(offs, max): return offs < max

@triton.jit
def get_2d_mask(offs_0, offs_1, max_0, max_1):
    """2D boundary mask: True where (offs_0[i], offs_1[j]) indexes a valid element.

    offs_0 is broadcast over rows (expanded to column shape) and offs_1 over
    columns (expanded to row shape), yielding a (len(offs_0), len(offs_1)) mask.
    Note: the duplicated `def` line from the diff residue has been removed —
    only one definition remains.
    """
    return (tl.expand_dims(offs_0, 1) < max_0) & (tl.expand_dims(offs_1, 0) < max_1)

0 comments on commit 1ed758a

Please sign in to comment.