Add some comments to clarify
imanmousaei committed Mar 25, 2023
1 parent 07e7608 commit d0f711b
Showing 5 changed files with 39 additions and 1 deletion.
Binary file not shown.
9 changes: 9 additions & 0 deletions exp/exp_classification.py
@@ -103,6 +103,15 @@ def train(self, setting):
iter_count += 1
model_optim.zero_grad()

"""
# padding_mask is a binary matrix of shape (batch_size, sequence_length)
# where 0 denotes a padding element and 1 denotes a non-padding element.
"""
# print(f'train: padding_mask = {padding_mask}') # all True
print(f'train: batch_x.shape = {batch_x.shape}, label.shape = {label.shape}') # all True
# train: batch_x.shape = torch.Size([16, 152, 3]), label.shape = torch.Size([16, 1])
# N refers to the number of features = 3

batch_x = batch_x.float().to(self.device)
padding_mask = padding_mask.float().to(self.device)
label = label.to(self.device)
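For reference, a minimal sketch of how such a padding mask can be built from per-sample sequence lengths (the lengths values here are illustrative, not taken from this dataset):

import torch

lengths = torch.tensor([152, 150, 97])  # hypothetical per-sample lengths
max_len = 152                           # the batch is padded to this length

# 1 (True) where a timestep holds real data, 0 (False) where it is padding
padding_mask = torch.arange(max_len)[None, :] < lengths[:, None]
print(padding_mask.shape)     # torch.Size([3, 152])
print(padding_mask[1, 148:])  # tensor([ True,  True, False, False])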
5 changes: 5 additions & 0 deletions exp/exp_imputation.py
@@ -46,6 +46,11 @@ def vali(self, vali_data, vali_loader, criterion):

# random mask
B, T, N = batch_x.shape
"""
B = batch size
T = seq len
N = number of features
"""
mask = torch.rand((B, T, N)).to(self.device)
mask[mask <= self.args.mask_rate] = 0 # masked
mask[mask > self.args.mask_rate] = 1 # remained
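The thresholding above turns a uniform sample into a 0/1 keep-mask with roughly mask_rate of entries masked. A minimal, self-contained sketch of the same idea (toy shapes, not the repository's defaults):

import torch

B, T, N = 2, 4, 3            # batch size, sequence length, number of features
mask_rate = 0.25             # expected fraction of entries to mask

mask = torch.rand((B, T, N))
mask[mask <= mask_rate] = 0  # masked positions
mask[mask > mask_rate] = 1   # kept positions

x = torch.randn(B, T, N)
x_masked = x * mask          # zero out the masked entries
print(f'masked fraction = {(mask == 0).float().mean():.2f}')  # ~0.25 in expectation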
3 changes: 3 additions & 0 deletions layers/AutoCorrelation.py
@@ -148,6 +148,9 @@ def forward(self, queries, keys, values, attn_mask):
_, S, _ = keys.shape
H = self.n_heads

# B = batch size, L = query length, S = key length, H = number of attention heads
print(f'AutoCorrelationLayer: B = {B}, L = {L}, S = {S}, H = {H}')
# B = 16, L = 152, S = 152, H = 8

queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
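A standalone sketch of the multi-head reshape these projections perform, assuming (as in the logged shapes) d_model = 128 split across H = 8 heads:

import torch
import torch.nn as nn

B, L, d_model, H = 16, 152, 128, 8

query_projection = nn.Linear(d_model, d_model)
queries = torch.randn(B, L, d_model)

# project, then split the model dimension into H heads of size d_model // H
q = query_projection(queries).view(B, L, H, -1)
print(q.shape)  # torch.Size([16, 152, 8, 16])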
23 changes: 22 additions & 1 deletion models/FEDformer.py
@@ -147,18 +147,39 @@ def anomaly_detection(self, x_enc):
dec_out = self.projection(enc_out)
return dec_out

def classification(self, x_enc, x_mark_enc):
"""
Args:
x_enc: a tensor representing the input sequence of the model
x_mark_enc: the padding mask, indicating which elements of the input
sequence are actual data (True) and which are padding (False)
"""

# enc
enc_out = self.enc_embedding(x_enc, None) # map each element of the input sequence to a continuous vector space
enc_out, attns = self.encoder(enc_out, attn_mask=None) # to capture the relationships between the elements in the sequence.

print(f'enc_out.shape = {enc_out.shape}, len(attns) = {len(attns)}')
# enc_out.shape = torch.Size([16, 152, 128]), len(attns) = 3
# i.e. batch_size = 16, seq_length = 152, d_model = 128

# Output
output = self.act(enc_out) # the transformer encoder's output embeddings include no non-linearity, so apply one here
output = self.dropout(output)
print(f'output1.shape = {output.shape}')
# torch.Size([16, 152, 128])

output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
print(f'output2.shape = {output.shape}')
# torch.Size([16, 152, 128])

output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model)
print(f'output3.shape = {output.shape}')
# output3.shape = torch.Size([16, 19456]), 19456 = 152 * 128

output = self.projection(output) # (batch_size, num_classes)
print(f'output4.shape = {output.shape}')
# output4.shape = torch.Size([16, 26])
return output

def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
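Putting the logged shapes together, a self-contained sketch of this classification head (illustrative only: the GELU activation for self.act and num_classes = 26 are assumptions inferred from the prints above):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, d_model, num_classes = 16, 152, 128, 26

enc_out = torch.randn(B, T, d_model)        # stand-in for the encoder output
x_mark_enc = torch.ones(B, T)               # padding mask (1 = data, 0 = padding)
projection = nn.Linear(T * d_model, num_classes)

output = F.gelu(enc_out)                    # add a non-linearity
output = output * x_mark_enc.unsqueeze(-1)  # zero out padding positions
output = output.reshape(B, -1)              # (16, 152 * 128) = (16, 19456)
output = projection(output)                 # (16, 26)
print(output.shape)  # torch.Size([16, 26])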
