diff --git a/aot/networks/layers/attention.py b/aot/networks/layers/attention.py
index 83f107c2..2bd2598a 100644
--- a/aot/networks/layers/attention.py
+++ b/aot/networks/layers/attention.py
@@ -821,7 +821,7 @@ def forward(self, q, k, v, u, size_2d):
                 n, self.num_head, self.window_size * self.window_size, h * w)
         else:
             unfolded_k = self.pad_and_unfold(k).view(
-                n * self.num_head, hidden_dim,
+                n * self.num_head, self.d_att,
                 self.window_size * self.window_size, h, w)
             qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view(
                 n, self.num_head, self.window_size * self.window_size, h * w)
diff --git a/aot/networks/layers/transformer.py b/aot/networks/layers/transformer.py
index 211cd1c5..5f20ab82 100755
--- a/aot/networks/layers/transformer.py
+++ b/aot/networks/layers/transformer.py
@@ -560,6 +560,13 @@ def __init__(self,
                                                top_k=-1,
                                                expand_ratio=expand_ratio)
 
+        if enable_corr:
+            try:
+                import spatial_correlation_sampler
+            except Exception as inst:
+                print(inst)
+                print("Failed to import PyTorch Correlation, For better efficiency, please install it.")
+                enable_corr = False
         self.short_term_attn = LocalGatedPropagation(d_qk=self.d_model,
                                                      d_vu=self.d_model * 2,
                                                      num_head=att_nhead,
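
Note on the first hunk: by the time the keys reach `pad_and_unfold` they have already been split into heads, so the channel axis of the unfolded tensor holds the per-head key width `self.d_att` (times the window area), not the full `hidden_dim`. Below is a minimal shape sketch of that fallback path; it is not the repository's code, and the sizes and the simplified stand-in for `pad_and_unfold` are assumptions for illustration only.

    import torch
    import torch.nn.functional as F

    # Assumed sizes, chosen only to make the shapes concrete.
    n, num_head, d_att, h, w, window_size = 2, 4, 16, 8, 8, 5
    pad = window_size // 2

    # Keys already split into heads: d_att channels per head.
    k = torch.randn(n * num_head, d_att, h, w)

    # Simplified stand-in for pad_and_unfold: zero-pad, then extract
    # window_size x window_size patches at every spatial position.
    unfolded_k = F.unfold(F.pad(k, (pad, pad, pad, pad)),
                          kernel_size=window_size, stride=1)

    # Each position now carries d_att * window_size**2 values, so the view
    # must use the per-head width d_att, not the full hidden_dim.
    unfolded_k = unfolded_k.view(n * num_head, d_att,
                                 window_size * window_size, h, w)
    print(unfolded_k.shape)  # torch.Size([8, 16, 25, 8, 8])

The second hunk simply demotes `enable_corr` to `False` when `spatial_correlation_sampler` cannot be imported, so the block falls back to this unfold path instead of failing at construction time.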