Skip to content

Commit

Permalink
adding the length of the output signal to irfft
Browse files Browse the repository at this point in the history
elisim authored Mar 6, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 2a4a0e1 commit fcb3eae
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion layers/AutoCorrelation.py
Original file line number Diff line number Diff line change
@@ -115,7 +115,7 @@ def forward(self, queries, keys, values, attn_mask):
q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
res = q_fft * torch.conj(k_fft)
corr = torch.fft.irfft(res, dim=-1)
corr = torch.fft.irfft(res, n=L, dim=-1)

# time delay agg
if self.training:

0 comments on commit fcb3eae

Please sign in to comment.