Skip to content

Commit

Permalink
fix enable corr for deaot
Browse files Browse the repository at this point in the history
  • Loading branch information
yoxu515 committed Apr 27, 2023
1 parent 6eeb579 commit 6497cce
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aot/networks/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def forward(self, q, k, v, u, size_2d):
n, self.num_head, self.window_size * self.window_size, h * w)
else:
unfolded_k = self.pad_and_unfold(k).view(
n * self.num_head, hidden_dim,
n * self.num_head, self.d_att,
self.window_size * self.window_size, h, w)
qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view(
n, self.num_head, self.window_size * self.window_size, h * w)
Expand Down
7 changes: 7 additions & 0 deletions aot/networks/layers/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,13 @@ def __init__(self,
top_k=-1,
expand_ratio=expand_ratio)

if enable_corr:
try:
import spatial_correlation_sampler
except Exception as inst:
print(inst)
print("Failed to import PyTorch Correlation, For better efficiency, please install it.")
enable_corr = False
self.short_term_attn = LocalGatedPropagation(d_qk=self.d_model,
d_vu=self.d_model * 2,
num_head=att_nhead,
Expand Down

0 comments on commit 6497cce

Please sign in to comment.