From dd9a30d3a8508282e351b5ef72dcff04fd9a8e21 Mon Sep 17 00:00:00 2001
From: Samar Khanna
Date: Mon, 21 Nov 2022 22:35:21 -0800
Subject: [PATCH] Fix temporal embedding stuff

---
 models_vit_temporal.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/models_vit_temporal.py b/models_vit_temporal.py
index 9983ac1..5589bae 100755
--- a/models_vit_temporal.py
+++ b/models_vit_temporal.py
@@ -50,22 +50,25 @@ def forward(self, x, timestamps, return_features=False):
 
     def forward_features(self, x, timestamps):
         B = x.shape[0]
-        x1 = self.patch_embed(x[:, 0])
-        x2 = self.patch_embed(x[:, 1])
-        x3 = self.patch_embed(x[:, 2])
-        x = torch.cat([x1, x2, x3], dim=1)
-        ts_embed = torch.cat([get_1d_sincos_pos_embed_from_grid_torch(128, timestamps.reshape(-1, 3)[:, 0].float()),
-                              get_1d_sincos_pos_embed_from_grid_torch(128, timestamps.reshape(-1, 3)[:, 1].float()),
-                              get_1d_sincos_pos_embed_from_grid_torch(128, timestamps.reshape(-1, 3)[:, 2].float())], dim=1).float()
-        ts_embed = ts_embed.reshape(-1, 3, ts_embed.shape[-1]).unsqueeze(2)
-        ts_embed = ts_embed.expand(-1, -1, x.shape[1] // 3, -1).reshape(x.shape[0], -1, ts_embed.shape[-1])
-        ts_embed = torch.cat([torch.zeros((ts_embed.shape[0], 1, ts_embed.shape[2]), device=ts_embed.device), ts_embed], dim=1)
+        T = x.shape[1]
+
+        patches = []
+        for t in range(T):
+            patches.append(self.patch_embed(x[:, t]))  # (B, L, D) per timestep
+        x = torch.cat(patches, dim=1)  # (B, T*L, D)
+
+        ts = timestamps.view(-1, 3).float()  # (B*T, 3) where 3 is for yr, mo, hr
+        ts_embed = torch.cat([get_1d_sincos_pos_embed_from_grid_torch(384 // 3, ts[:, i]) for i in range(3)], dim=1)  # (B*T, 384)
+        ts_embed = ts_embed.view(B, T, ts_embed.shape[-1]).unsqueeze(2)  # (B, T, 1, 384)
+        ts_embed = ts_embed.expand(-1, -1, x.shape[1] // T, -1).reshape(B, -1, ts_embed.shape[-1])  # (B, T*L, 384)
+        ts_embed = torch.cat([torch.zeros((B, 1, ts_embed.shape[-1]), device=ts_embed.device), ts_embed], dim=1)  # (B, T*L + 1, 384); zero ts embedding for the cls token
+
+        pos_embed = torch.cat((self.pos_embed[:, :1, :], self.pos_embed[:, 1:, :].repeat(1, T, 1)), dim=1)  # (1, T*L + 1, D-384)
+        total_embed = torch.cat((pos_embed.expand(B, -1, -1), ts_embed), dim=-1)  # (B, T*L + 1, D)
 
         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
         x = torch.cat((cls_tokens, x), dim=1)
-        x = x + torch.cat(
-            [torch.cat([self.pos_embed[:, :1, :], self.pos_embed[:, 1:, :].repeat(1, 3, 1)], dim=1).expand(ts_embed.shape[0], -1, -1),
-             ts_embed], dim=-1)
+        x = x + total_embed
         x = self.pos_drop(x)