Skip to content

Commit

Permalink
🎨 | 🐛
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenye-Na committed May 23, 2018
1 parent 9ad5e8b commit ecc2d4b
Showing 1 changed file with 85 additions and 22 deletions.
107 changes: 85 additions & 22 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
"""
from ops import *
from tqdm import tqdm
# import matplotlib.pyplot as plt
from torch.autograd import Variable

import torch
Expand All @@ -19,10 +17,9 @@
from torch import optim
import torch.nn.functional as F

# import matplotlib
# matplotlib.use('Agg')

# self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class Encoder(nn.Module):
Expand Down Expand Up @@ -55,8 +52,10 @@ def forward(self, X):
X
"""
X_tilde = Variable(X.data.new(X.size(0), self.T - 1, self.input_size).zero_())
X_encoded = Variable(X.data.new(X.size(0), self.T - 1, self.encoder_num_hidden).zero_())
X_tilde = Variable(X.data.new(
X.size(0), self.T - 1, self.input_size).zero_())
X_encoded = Variable(X.data.new(
X.size(0), self.T - 1, self.encoder_num_hidden).zero_())

# Eq. 8, parameters not in nn.Linear but to be learnt
# v_e = torch.nn.Parameter(data=torch.empty(
Expand All @@ -74,7 +73,8 @@ def forward(self, X):
s_n.repeat(self.input_size, 1, 1).permute(1, 0, 2),
X.permute(0, 2, 1)), dim=2)

x = self.encoder_attn(x.view(-1, self.encoder_num_hidden * 2 + self.T - 1))
x = self.encoder_attn(
x.view(-1, self.encoder_num_hidden * 2 + self.T - 1))

# get weights by softmax
alpha = F.softmax(x.view(-1, self.input_size))
Expand All @@ -84,7 +84,8 @@ def forward(self, X):

# encoder LSTM
self.encoder_lstm.flatten_parameters()
_, final_state = self.encoder_lstm(x_tilde.unsqueeze(0), (h_n, s_n))
_, final_state = self.encoder_lstm(
x_tilde.unsqueeze(0), (h_n, s_n))
h_n = final_state[0]
s_n = final_state[1]

Expand All @@ -104,7 +105,8 @@ def _init_states(self, X):
"""
# hidden state and cell state [num_layers*num_directions, batch_size, hidden_size]
# https://pytorch.org/docs/master/nn.html?#lstm
initial_states = Variable(X.data.new(1, X.size(0), self.encoder_num_hidden).zero_())
initial_states = Variable(X.data.new(
1, X.size(0), self.encoder_num_hidden).zero_())
return initial_states


Expand Down Expand Up @@ -173,7 +175,8 @@ def _init_states(self, X):
"""
# hidden state and cell state [num_layers*num_directions, batch_size, hidden_size]
# https://pytorch.org/docs/master/nn.html?#lstm
initial_states = Variable(X.data.new(1, X.size(0), self.decoder_num_hidden).zero_())
initial_states = Variable(X.data.new(
1, X.size(0), self.decoder_num_hidden).zero_())
return initial_states


Expand All @@ -197,6 +200,8 @@ def __init__(self, X, y, T,
self.shuffle = False
self.epochs = epochs
self.T = T
self.X = X
self.y = y

self.Encoder = Encoder(input_size=X.shape[1],
encoder_num_hidden=encoder_num_hidden,
Expand Down Expand Up @@ -226,10 +231,10 @@ def __init__(self, X, y, T,
# self.y_train = self.y_train.reshape(self.y_train.shape[0], -1)
# self.y_test = self.y_test.reshape(self.y_test.shape[0], -1)

self.total_timesteps = X.shape[0]
self.total_timesteps = self.X.shape[0]
self.train_timesteps = self.X_train.shape[0]
self.test_timesteps = self.X_test.shape[0]
self.input_size = X.shape[1]
self.input_size = self.X.shape[1]

def train(self):
"""training process."""
Expand All @@ -238,7 +243,7 @@ def train(self):
self.true = []
n_iter = 0

for epoch in tqdm(range(self.epochs)):
for epoch in range(self.epochs):
if self.shuffle:
ref_idx = np.random.permutation(self.train_timesteps - self.T)
else:
Expand All @@ -250,13 +255,13 @@ def train(self):
# get the indices of X_train
indices = ref_idx[idx:(idx + self.batch_size)]
# x = np.zeros((self.T - 1, len(indices), self.input_size))
x = np.zeros((len(indices), self.T - 1, self.input_size))
self.x = np.zeros((len(indices), self.T - 1, self.input_size))
y_prev = np.zeros((len(indices), self.T - 1))
y_gt = self.y_train[indices + self.T]

# format x into 3D tensor
for bs in range(len(indices)):
x[bs, :, :] = self.X_train[indices[bs]:(indices[bs] + self.T - 1), :]
self.x[bs, :, :] = self.X_train[indices[bs]:(indices[bs] + self.T - 1), :]
y_prev[bs, :] = self.y_train[indices[bs]:(indices[bs] + self.T - 1)]

n_iter += 1
Expand All @@ -267,7 +272,7 @@ def train(self):
self.decoder_optimizer.zero_grad()

input_weighted, input_encoded = self.Encoder(
Variable(torch.from_numpy(x).type(torch.FloatTensor)))
Variable(torch.from_numpy(self.x).type(torch.FloatTensor)))
y_pred = self.Decoder(input_encoded, Variable(
torch.from_numpy(y_prev).type(torch.FloatTensor)))

Expand All @@ -284,15 +289,30 @@ def train(self):
self.encoder_optimizer.step()
self.decoder_optimizer.step()

if n_iter % 5000 == 0 and n_iter != 0:
if n_iter % 500 == 0 and n_iter != 0:
for param_group in self.encoder_optimizer.param_groups:
param_group['lr'] = param_group['lr'] * 0.9
for param_group in self.decoder_optimizer.param_groups:
param_group['lr'] = param_group['lr'] * 0.9

if n_iter % 100 == 0:
if n_iter % 10 == 0:
print("Iterations: ", n_iter, "\tLoss: ", self.loss[-1])

if epoch % 2 == 0:
    y_train_pred = self.test(on_train=True)
    # BUG FIX: this call was commented out while y_test_pred was still
    # used below, which raised NameError on the first plotting epoch.
    y_test_pred = self.test(on_train=False)
    y_pred = np.concatenate((y_train_pred, y_test_pred))
    # BUG FIX: bind the figure so plt.close(fig) below has a handle
    # (original called plt.close(fig) with `fig` never assigned).
    fig = plt.figure()
    plt.plot(range(1, 1 + len(self.y_train)),
             self.y_train, label="True")
    plt.plot(range(self.T, len(y_train_pred) + self.T),
             y_train_pred, label='Predicted - Train')
    # BUG FIX: the test curve starts right after the train predictions
    # and must span exactly len(y_test_pred) x-values; the original
    # range had non-positive length and could not match y_test_pred.
    plt.plot(range(self.T + len(y_train_pred), len(self.y) + 1),
             y_test_pred, label='Predicted - Test')
    plt.legend(loc='upper left')
    plt.savefig("plot_" + str(epoch) + ".png")
    plt.close(fig)

# Save files in last iterations
if epoch == self.epochs - 1:
np.savetxt('../loss.txt', np.array(self.loss), delimiter=',')
Expand All @@ -305,6 +325,49 @@ def val(self):
"""validation."""
pass

def test(self, on_train=False):
    """Run inference with the trained Encoder/Decoder.

    Args:
        on_train (bool): when True, predict one value for every window
            that fits inside the training split; otherwise predict one
            value per test timestep.

    Returns:
        np.ndarray: 1-D array of predictions.
    """
    if on_train:
        y_pred = np.zeros(self.train_timesteps - self.T + 1)
    else:
        # BUG FIX: was `self.X_test.shape[0] - self.test_timesteps`,
        # which is always 0 (X_test has exactly test_timesteps rows),
        # so the test split never produced any predictions.
        y_pred = np.zeros(self.test_timesteps)

    i = 0
    while i < len(y_pred):
        batch_idx = np.array(range(len(y_pred)))[i:(i + self.batch_size)]

        X = np.zeros((len(batch_idx), self.T - 1, self.input_size))
        y_history = np.zeros((len(batch_idx), self.T - 1))

        for j in range(len(batch_idx)):
            if on_train:
                # History window of T-1 steps starting at the train index.
                lo = batch_idx[j]
                hi = batch_idx[j] + self.T - 1
                X[j, :, :] = self.X_train[lo:hi, :]
                y_history[j, :] = self.y_train[lo:hi]
            else:
                # For test index k the T-1 history steps may reach back
                # into the training split, so index the full series
                # self.X / self.y with a train_timesteps offset.
                # BUG FIX: the original sliced X_test with indices
                # `batch_idx[j] + self.test_timesteps - 1`, which run
                # past the end of X_test for every j >= 1.
                lo = batch_idx[j] + self.train_timesteps - self.T
                hi = batch_idx[j] + self.train_timesteps - 1
                X[j, :, :] = self.X[lo:hi, :]
                y_history[j, :] = self.y[lo:hi]

        y_history = Variable(
            torch.from_numpy(y_history).type(torch.FloatTensor))
        # BUG FIX: encode the batch `X` built above, not the entire
        # X_train / X_test matrix (the original built X and never used
        # it, feeding the full dataset to the Encoder instead). This
        # also collapses the two duplicated if/else encode/decode paths.
        _, input_encoded = self.Encoder(
            Variable(torch.from_numpy(X).type(torch.FloatTensor)))
        y_pred[i:(i + self.batch_size)] = self.Decoder(
            input_encoded, y_history).cpu().data.numpy()[:, 0]
        i += self.batch_size

    return y_pred

0 comments on commit ecc2d4b

Please sign in to comment.