
Commit

clean
wuhaixu2016 committed Mar 25, 2023
1 parent 38ac444 commit 4a0826a
Showing 5 changed files with 11 additions and 61 deletions.
18 changes: 1 addition & 17 deletions data_provider/data_factory.py
@@ -18,19 +18,6 @@
     'UEA': UEAloader
 }
 
-def save_seq_len(seq_len):
-    with open('seq_len', 'w') as f:
-        f.write(str(seq_len))
-
-def load_seq_len():
-    with open('seq_len', 'r') as f:
-        return int(f.read())
-
-
-def collate_fn_local(x):
-    seq_len=load_seq_len()
-    return collate_fn(x, max_len=seq_len)
-
 
 def data_provider(args, flag):
     Data = data_dict[args.data]
@@ -71,17 +58,14 @@ def data_provider(args, flag):
             root_path=args.root_path,
             flag=flag,
         )
-        print(f'.......creating dataloader........', flag, len(data_set), args.data)
-
-        save_seq_len(args.seq_len)
 
         data_loader = DataLoader(
             data_set,
             batch_size=batch_size,
             shuffle=shuffle_flag,
             num_workers=args.num_workers,
             drop_last=drop_last,
-            collate_fn=collate_fn_local
+            collate_fn=lambda x: collate_fn(x, max_len=args.seq_len)
         )
         return data_set, data_loader
     else:
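The deleted save_seq_len/load_seq_len round trip through a temp file existed only to hand args.seq_len to the collate function; the new lambda closes over it directly. A lambda is not picklable, however, so when num_workers > 0 and the workers are started with the 'spawn' method (Windows/macOS), binding the length with functools.partial is a more robust equivalent. A minimal sketch, assuming collate_fn is the padding collate used by data_factory.py and that it accepts the max_len keyword shown in the diff:

```python
from functools import partial

from torch.utils.data import DataLoader
from data_provider.uea import collate_fn  # assumed import path, as used by data_factory.py


def build_classification_loader(data_set, args, batch_size, shuffle_flag, drop_last):
    """Hypothetical helper: bind max_len once instead of using a lambda.

    partial(collate_fn, max_len=...) does the same job as the lambda in the
    diff, but it stays picklable, which matters when DataLoader workers are
    spawned in separate processes.
    """
    return DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
        collate_fn=partial(collate_fn, max_len=args.seq_len),
    )
```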
11 changes: 1 addition & 10 deletions exp/exp_classification.py
@@ -103,15 +103,6 @@ def train(self, setting):
                 iter_count += 1
                 model_optim.zero_grad()
 
-                """
-                # padding_mask is a binary matrix of shape (batch_size, sequence_length)
-                # where 0 denotes a padding element and 1 denotes a non-padding element.
-                """
-                # print(f'train: padding_mask = {padding_mask}') # all True
-                print(f'train: batch_x.shape = {batch_x.shape}, label.shape = {label.shape}') # all True
-                # train: batch_x.shape = torch.Size([16, 152, 3]), label.shape = torch.Size([16, 1])
-                # N refers to the number of features = 3
-
                 batch_x = batch_x.float().to(self.device)
                 padding_mask = padding_mask.float().to(self.device)
                 label = label.to(self.device)
@@ -139,7 +130,7 @@ def train(self, setting):
 
             print(
                 "Epoch: {0}, Steps: {1} | Train Loss: {2:.3f} Vali Loss: {3:.3f} Vali Acc: {4:.3f} Test Loss: {5:.3f} Test Acc: {6:.3f}"
-                .format(epoch + 1, train_steps, train_loss, vali_loss, val_accuracy, test_loss, test_accuracy))
+                .format(epoch + 1, train_steps, train_loss, vali_loss, val_accuracy, test_loss, test_accuracy))
             early_stopping(-val_accuracy, self.model, path)
             if early_stopping.early_stop:
                 print("Early stopping")
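The deleted docstring documented the padding_mask contract: a (batch_size, sequence_length) binary matrix in which 1/True marks real time steps and 0/False marks padding, for batches shaped like batch_x = (16, 152, 3) and label = (16, 1). A small illustrative sketch of how such a mask arises when variable-length series are padded to a common max_len (the repository's collate_fn does this for real; the helper below is hypothetical):

```python
import torch


def pad_batch(seqs, max_len):
    """Pad a list of (T_i, N) float tensors to (len(seqs), max_len, N).

    Returns the padded batch and a (batch, max_len) boolean mask where
    True marks real data and False marks padding, as in the deleted comment.
    """
    num_features = seqs[0].shape[1]
    batch_x = torch.zeros(len(seqs), max_len, num_features)
    padding_mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
    for i, seq in enumerate(seqs):
        batch_x[i, : seq.shape[0]] = seq
        padding_mask[i, : seq.shape[0]] = True
    return batch_x, padding_mask


x, mask = pad_batch([torch.randn(152, 3), torch.randn(120, 3)], max_len=152)
print(x.shape, mask.shape)  # torch.Size([2, 152, 3]) torch.Size([2, 152])
```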
3 changes: 0 additions & 3 deletions layers/AutoCorrelation.py
@@ -148,9 +148,6 @@ def forward(self, queries, keys, values, attn_mask):
         _, S, _ = keys.shape
         H = self.n_heads
 
-        print(f'AutoCorrelationLayer: B = {B}, L = {L}, S = {S}, H = {H}')
-        # B = 16, L = 152, S = 152, H = 8
-
         queries = self.query_projection(queries).view(B, L, H, -1)
         keys = self.key_projection(keys).view(B, S, H, -1)
         values = self.value_projection(values).view(B, S, H, -1)
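The deleted print only inspected the head split performed right below it: each projection keeps d_model features and view(B, L, H, -1) carves them into H heads. A short sketch using the sizes from the removed comment (B = 16, L = S = 152, H = 8) together with d_model = 128 from the debug output deleted in models/FEDformer.py:

```python
import torch
import torch.nn as nn

B, L, H, d_model = 16, 152, 8, 128            # sizes from the removed debug comments
query_projection = nn.Linear(d_model, d_model)

queries = torch.randn(B, L, d_model)
# view(B, L, H, -1) splits d_model into H heads of size d_model // H = 16
q = query_projection(queries).view(B, L, H, -1)
print(q.shape)  # torch.Size([16, 152, 8, 16])
```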
38 changes: 8 additions & 30 deletions models/FEDformer.py
@@ -71,7 +71,7 @@ def __init__(self, configs, version='fourier', mode_select='random', modes=32):
             [
                 EncoderLayer(
                     AutoCorrelationLayer(
-                        encoder_self_att, # instead of multi-head attention in transformer
+                        encoder_self_att, # instead of multi-head attention in transformer
                         configs.d_model, configs.n_heads),
                     configs.d_model,
                     configs.d_ff,
@@ -117,7 +117,7 @@ def __init__(self, configs, version='fourier', mode_select='random', modes=32):
     def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
         # decomp init
         mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
-        seasonal_init, trend_init = self.decomp(x_enc) # x - moving_avg, moving_avg
+        seasonal_init, trend_init = self.decomp(x_enc) # x - moving_avg, moving_avg
         # decoder input
         trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
         seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len))
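The "# x - moving_avg, moving_avg" comment summarizes what self.decomp returns: the trend is a moving average of the input and the seasonal part is the residual. A simplified sketch of that idea (the repository's series_decomp also replicates the series ends before pooling; the kernel size and shapes here are illustrative):

```python
import torch
import torch.nn as nn


class SimpleDecomp(nn.Module):
    """Toy seasonal-trend split: trend = moving average, seasonal = x - trend."""

    def __init__(self, kernel_size=25):
        super().__init__()
        self.avg = nn.AvgPool1d(kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):  # x: (batch, length, channels)
        trend = self.avg(x.permute(0, 2, 1)).permute(0, 2, 1)
        seasonal = x - trend
        return seasonal, trend


x_enc = torch.randn(16, 96, 7)
seasonal_init, trend_init = SimpleDecomp()(x_enc)
print(seasonal_init.shape, trend_init.shape)  # both torch.Size([16, 96, 7])
```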
@@ -148,38 +148,16 @@ def anomaly_detection(self, x_enc):
         return dec_out
 
     def classification(self, x_enc, x_mark_enc):
-        """
-        Args:
-            x_enc: a tensor representing the input sequence of the model
-            x_mark_enc: padding_mask: This sequence is used to indicate which elements of the input sequence are actual data and which are padding
-                        True: data, False: padding
-        """
-
         # enc
-        enc_out = self.enc_embedding(x_enc, None) # map each element of the input sequence to a continuous vector space
-        enc_out, attns = self.encoder(enc_out, attn_mask=None) # to capture the relationships between the elements in the sequence.
-
-        # batch size = 16, seq_length = 152, channels = 128
-        print(f'enc_out.shape = {enc_out.shape}, len(attns) = {len(attns)}')
-        # enc_out.shape = torch.Size([16, 152, 128]), len(attns) = 3
+        enc_out = self.enc_embedding(x_enc, None)
+        enc_out, attns = self.encoder(enc_out, attn_mask=None)
 
         # Output
-        output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity
+        output = self.act(enc_out)
         output = self.dropout(output)
-        print(f'output1.shape = {output.shape}')
-        # torch.Size([16, 152, 128])
-
-        output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings
-        print(f'output2.shape = {output.shape}')
-        # torch.Size([16, 152, 128])
-
-        output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model)
-        print(f'output3.shape = {output.shape}')
-        # output3.shape = torch.Size([16, 19456]) , 19456 = 152*128
-
-        output = self.projection(output) # (batch_size, num_classes)
-        print(f'output4.shape = {output.shape}')
-        # output4.shape = torch.Size([16, 26])
+        output = output * x_mark_enc.unsqueeze(-1)
+        output = output.reshape(output.shape[0], -1)
+        output = self.projection(output)
         return output
 
     def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
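For reference, the deleted debug comments recorded the shape flow through the classification head: encoder output (16, 152, 128), masked by the padding mask, flattened to (16, 19456) with 19456 = 152 * 128, and projected to (16, 26) class logits. A standalone sketch of that flow; the activation, dropout rate, and class count are stand-ins consistent with those comments, not the configured values:

```python
import torch
import torch.nn as nn

B, T, D, num_classes = 16, 152, 128, 26       # sizes from the deleted comments
enc_out = torch.randn(B, T, D)                # stand-in for the encoder output
x_mark_enc = torch.ones(B, T)                 # padding mask: 1 = data, 0 = padding

act, dropout = nn.GELU(), nn.Dropout(0.1)     # stand-ins for self.act / self.dropout
projection = nn.Linear(T * D, num_classes)    # matches the flatten-then-project head

output = dropout(act(enc_out))                # (16, 152, 128)
output = output * x_mark_enc.unsqueeze(-1)    # zero out padded positions
output = output.reshape(B, -1)                # (16, 19456) = (16, 152 * 128)
logits = projection(output)                   # (16, 26)
print(logits.shape)
```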
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-einops==0.6.0
+einops==0.4.0
 matplotlib==3.7.0
 numpy==1.23.5
 pandas==1.5.3
