Fix temporal embedding stuff
samar-khanna committed Nov 22, 2022
1 parent fd582f1 commit dd9a30d
Showing 1 changed file with 16 additions and 13 deletions.
models_vit_temporal.py: 29 changes (16 additions, 13 deletions)
@@ -50,22 +50,25 @@ def forward(self, x, timestamps, return_features=False):
    def forward_features(self, x, timestamps):

        B = x.shape[0]
-        x1 = self.patch_embed(x[:, 0])
-        x2 = self.patch_embed(x[:, 1])
-        x3 = self.patch_embed(x[:, 2])
-        x = torch.cat([x1, x2, x3], dim=1)
-        ts_embed = torch.cat([get_1d_sincos_pos_embed_from_grid_torch(128, timestamps.reshape(-1, 3)[:, 0].float()),
-                              get_1d_sincos_pos_embed_from_grid_torch(128, timestamps.reshape(-1, 3)[:, 1].float()),
-                              get_1d_sincos_pos_embed_from_grid_torch(128, timestamps.reshape(-1, 3)[:, 2].float())], dim=1).float()
-        ts_embed = ts_embed.reshape(-1, 3, ts_embed.shape[-1]).unsqueeze(2)
-        ts_embed = ts_embed.expand(-1, -1, x.shape[1] // 3, -1).reshape(x.shape[0], -1, ts_embed.shape[-1])
-        ts_embed = torch.cat([torch.zeros((ts_embed.shape[0], 1, ts_embed.shape[2]), device=ts_embed.device), ts_embed], dim=1)
+        T = x.shape[1]
+
+        patches = []
+        for t in range(T):
+            patches.append(self.patch_embed(x[:, t]))
+        x = torch.cat(patches, dim=1)  # (B, T*L, D)
+
+        ts = timestamps.view(-1, 3).float()  # (B*T, 3) where 3 is for yr, mo, hr
+        ts_embed = torch.cat([get_1d_sincos_pos_embed_from_grid_torch(384 // T, ts[:, i]) for i in range(3)], dim=1)  # (B*T, 384)
+        ts_embed = ts_embed.view(B, T, ts_embed.shape[-1]).unsqueeze(2)  # (B, T, 1, 384)
+        ts_embed = ts_embed.expand(-1, -1, x.shape[1] // T, -1)  # (B, T, L, 384)
+        ts_embed = ts_embed.view(B, -1, ts_embed.shape[-1])  # (B, T*L, 384)
+
+        pos_embed = torch.cat((self.pos_embed[:, :1, :], self.pos_embed[:, 1:, :].repeat(1, T, 1)), dim=1)  # (1, T*L + 1, D-384)
+        total_embed = torch.cat((pos_embed.expand(B, -1, -1), ts_embed), dim=-1)  # (B, T*L + 1, D)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
-        x = x + torch.cat(
-            [torch.cat([self.pos_embed[:, :1, :], self.pos_embed[:, 1:, :].repeat(1, 3, 1)], dim=1).expand(ts_embed.shape[0], -1, -1),
-             ts_embed], dim=-1)
+        x = x + total_embed

        x = self.pos_drop(x)

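A note on the shape arithmetic in the added lines: below is a minimal, self-contained sketch, assuming get_1d_sincos_pos_embed_from_grid_torch follows the standard MAE-style 1D sin-cos formulation. sincos_embed_1d is a hypothetical stand-in for that helper, and the sizes B, T, L and the 384-channel temporal slice are assumptions taken from the inline shape comments in the diff, not verified against the repository.

import torch

def sincos_embed_1d(embed_dim, pos):
    # Hypothetical stand-in for get_1d_sincos_pos_embed_from_grid_torch:
    # MAE-style 1D sin-cos embedding of a 1-D tensor of scalar positions.
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega = 1.0 / 10000 ** (omega / (embed_dim / 2))          # (embed_dim/2,)
    out = pos.float().reshape(-1)[:, None] * omega[None, :]   # (M, embed_dim/2)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1) # (M, embed_dim)

# Assumed sizes: B samples, T images per sample, L patches per image.
B, T, L = 2, 3, 196
timestamps = torch.randint(0, 12, (B, T, 3))    # (B, T, 3): yr, mo, hr

ts = timestamps.view(-1, 3).float()             # (B*T, 3)
ts_embed = torch.cat([sincos_embed_1d(384 // T, ts[:, i]) for i in range(3)], dim=1)
print(ts_embed.shape)                           # torch.Size([6, 384]) = (B*T, 3 * 384//T)

ts_embed = ts_embed.view(B, T, -1).unsqueeze(2) # (B, T, 1, 384)
ts_embed = ts_embed.expand(-1, -1, L, -1)       # (B, T, L, 384)
ts_embed = ts_embed.reshape(B, T * L, -1)       # reshape, since expand() gives a non-contiguous tensor
print(ts_embed.shape)                           # torch.Size([2, 588, 384]) = (B, T*L, 384)

With T = 3 the per-component width is 384 // T = 128, so the three timestamp components (year, month, hour) concatenate to 384 channels per patch, matching the (B, T*L, 384) comment in the diff; the remaining D - 384 channels come from the repeated spatial pos_embed.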
