
Commit

fix potential bug
lucidrains committed Jul 29, 2021
1 parent ef34b82 commit b5bb4f8
Showing 1 changed file with 5 additions and 1 deletion.
h_transformer_1d/h_transformer_1d.py
@@ -98,7 +98,7 @@ def forward(self, x, mask = None):
         def calculate_Y_and_A(q, k, v, mask_A = False, remove_right_off_diagonals = False):
             if remove_right_off_diagonals:
                 q, k, v = map(lambda t: rearrange(t, 'b (n r) z d -> b n r z d', r = 2), (q, k, v))
-                q, k, v = map(lambda t: t[:, :, 1], (q, k, v))
+                q, k, v = map(lambda t: t[:, :, 0], (q, k, v))

             S = einsum('... i d, ... j d -> ... i j', q, k)
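A minimal sketch, not part of this commit, of what the pair-split and index selection above do (it assumes torch and einops are installed): after grouping consecutive blocks into pairs, t[:, :, 0] keeps the first block of each pair and t[:, :, 1] the second. Since the coarsened keys and values are flipped within each pair before this branch runs (see the second hunk below), the retained index determines which off-diagonal block each query block is paired with.

# Illustrative only: label blocks 0..3 and watch which ones survive
# the pair-split plus index selection used in the hunk above.
import torch
from einops import rearrange

t = torch.arange(4).reshape(1, 4, 1, 1)              # b = 1, (n r) = 4 blocks, z = 1, d = 1
t = rearrange(t, 'b (n r) z d -> b n r z d', r = 2)  # group consecutive blocks into pairs

print(t[:, :, 0].flatten().tolist())  # [0, 2] -> first block of each pair (behaviour after this commit)
print(t[:, :, 1].flatten().tolist())  # [1, 3] -> second block of each pair (behaviour before)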

@@ -140,6 +140,10 @@ def calculate_Y_and_A(q, k, v, mask_A = False, remove_right_off_diagonals = False):
             k = torch.flip(k, dims = (2,)) # so we pay attention to the off-diagonal blocks in the attention matrix
             k = rearrange(k, 'b n r z d -> b (n r) z d')

+            v = rearrange(v, 'b (n r) z d -> b n r z d', r = 2)
+            v = torch.flip(v, dims = (2,))
+            v = rearrange(v, 'b n r z d -> b (n r) z d')
+
             coarsened_Y = calculate_Y_and_A(q, k, v, remove_right_off_diagonals = causal)
             Ys.append(coarsened_Y)
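A minimal sketch, not part of this commit, of the mismatch the added lines remove (shapes and names are illustrative): the block-diagonal attention weights at the coarsened levels are computed against keys that were flipped within block pairs, so the values must be permuted the same way; otherwise each query block aggregates the value block belonging to its neighbour.

# Illustrative only: flipping k within pairs but not v applies the
# attention weights to the wrong value blocks.
import torch
from einops import rearrange

def flip_pairs(t):
    t = rearrange(t, 'b (n r) z d -> b n r z d', r = 2)
    t = torch.flip(t, dims = (2,))                      # swap the two blocks in each pair
    return rearrange(t, 'b n r z d -> b (n r) z d')

b, blocks, z, d = 1, 4, 2, 8
q, k, v = (torch.randn(b, blocks, z, d) for _ in range(3))

k_f, v_f = flip_pairs(k), flip_pairs(v)

# block-diagonal attention against the flipped keys
attn = torch.einsum('b n i d, b n j d -> b n i j', q, k_f).softmax(dim = -1)

out_fixed = torch.einsum('b n i j, b n j d -> b n i d', attn, v_f)  # v flipped in step with k (this commit)
out_buggy = torch.einsum('b n i j, b n j d -> b n i d', attn, v)    # v left unflipped (before this commit)

print(torch.allclose(out_fixed, out_buggy))  # False: key/value blocks disagree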

