Skip to content

Commit

Permalink
fixes and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
axeloh committed Jun 18, 2021
1 parent dd28fe3 commit 8edf020
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions bash_scripts/train_smd_2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ python train.py --dataset SMD --group 3-6 --epochs 10 --use_gatv2 False --feat_g
python train.py --dataset SMD --group 3-7 --epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
python train.py --dataset SMD --group 3-8 --epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
python train.py --dataset SMD --group 3-9 --epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
python train.py --dataset SMD --group 3-10--epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
python train.py --dataset SMD --group 3-11--epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
python train.py --dataset SMD --group 3-10 --epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
python train.py --dataset SMD --group 3-11 --epochs 10 --use_gatv2 False --feat_gat_embed_dim 32 --time_gat_embed_dim 8
6 changes: 4 additions & 2 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ def _make_attention_input(self, v):
K = self.num_nodes
blocks_repeating = v.repeat_interleave(K, dim=1) # Left-side of the matrix
blocks_alternating = v.repeat(1, K, 1) # Right-side of the matrix

combined = torch.cat((blocks_repeating, blocks_alternating), dim=2) # (b, K*K, 2*window_size)

return combined.view(v.size(0), K, K, 2 * self.window_size)
if self.use_gatv2:
return combined.view(v.size(0), K, K, 2 * self.window_size)
else:
return combined.view(v.size(0), K, K, 2 * self.embed_dim)


class TemporalAttentionLayer(nn.Module):
Expand Down

0 comments on commit 8edf020

Please sign in to comment.