Skip to content

Commit

Permalink
Merge branch 'main' of github.com:z-x-yang/Segment-and-Track-Anything…
Browse files Browse the repository at this point in the history
… into main
  • Loading branch information
yamy-cheng committed Apr 26, 2023
2 parents 7bdc8c9 + 22444f8 commit 2dc9009
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions aot/networks/layers/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,17 @@ def __init__(self,
use_linear=False,
dropout=lt_dropout)

MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
# MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
if enable_corr:
try:
import spatial_correlation_sampler
MultiheadLocalAttention = MultiheadLocalAttentionV2
except Exception as inst:
print(inst)
print("Failed to import PyTorch Correlation, For better efficiency, please install it.")
MultiheadLocalAttention = MultiheadLocalAttentionV3
else:
MultiheadLocalAttention = MultiheadLocalAttentionV3
self.short_term_attn = MultiheadLocalAttention(d_model,
att_nhead,
dilation=local_dilation,
Expand Down Expand Up @@ -398,7 +408,17 @@ def __init__(self,
use_linear=False,
dropout=lt_dropout)

MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
# MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
if enable_corr:
try:
import spatial_correlation_sampler
MultiheadLocalAttention = MultiheadLocalAttentionV2
except Exception as inst:
print(inst)
print("Failed to import PyTorch Correlation, For better efficiency, please install it.")
MultiheadLocalAttention = MultiheadLocalAttentionV3
else:
MultiheadLocalAttention = MultiheadLocalAttentionV3
self.short_term_attn = MultiheadLocalAttention(d_model,
att_nhead,
dilation=local_dilation,
Expand Down Expand Up @@ -545,6 +565,7 @@ def __init__(self,
num_head=att_nhead,
dilation=local_dilation,
use_linear=False,
enable_corr=enable_corr,
dropout=st_dropout,
d_att=d_att,
max_dis=max_local_dis,
Expand Down

0 comments on commit 2dc9009

Please sign in to comment.