Skip to content

Commit

Permalink
Fit method fully tested
Browse files Browse the repository at this point in the history
  • Loading branch information
dorianb committed Nov 5, 2018
1 parent 8d31357 commit 0fb54bd
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 23 deletions.
44 changes: 29 additions & 15 deletions src/model/sequence_model/classRNN.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
import numpy as np
from time import time
import os

from sequence_model.classSequenceModel import SequenceModel
from sequence_model.classRNNCell import RNNCell
Expand Down Expand Up @@ -56,8 +57,8 @@ def __init__(self, units, f_out, batch_size=2, time_steps=24, n_features=10, n_o
self.optimizer_name = optimizer_name
self.learning_rate = learning_rate
self.loss_name = loss_name
self.summary_path = summary_path
self.checkpoint_path = checkpoint_path
self.summary_path = os.path.join(summary_path, name)
self.checkpoint_path = os.path.join(checkpoint_path, name)
self.name = name
self.logger = logger
self.debug = debug
Expand Down Expand Up @@ -148,7 +149,7 @@ def build_model(self, input_seq):

return tf.reshape(
tf.concat(layer_outputs, axis=1),
[self.batch_size, n_cells, self.n_output]
[-1, n_cells, self.n_output]
) if self.return_sequences else prev_output

def fit(self, train_set, validation_set, initial_states, initial_outputs=None):
Expand Down Expand Up @@ -197,14 +198,21 @@ def fit(self, train_set, validation_set, initial_states, initial_outputs=None):
time0 = time()
batch_examples = train_set[i - self.batch_size: i]

feature_batch, label_batch = self.load_batch(batch_examples)
feature_batch, label_batch = SequenceModel.load_batch(batch_examples)

feed_dict = {
self.input: feature_batch,
self.label: label_batch,
self.initial_outputs: initial_outputs
self.label: label_batch
}
feed_dict.update({i: d for i, d in zip(self.initial_states, initial_states)})
feed_dict.update({
i: np.repeat(d, self.batch_size, axis=0)
for i, d in zip(self.initial_states, initial_states)
})
feed_dict.update({
self.initial_outputs: np.repeat(
initial_outputs, self.batch_size, axis=0
).reshape(initial_outputs.shape[0], self.batch_size, initial_outputs.shape[1])
}) if self.with_prev_output and initial_outputs is not None else None

_, loss_value, summaries_value, step = sess.run([
train_op, loss, summaries, self.global_step],
Expand Down Expand Up @@ -245,15 +253,21 @@ def validation_eval(self, session, summaries, dataset, initial_states, initial_o

feed_dict = {
self.input: inputs,
self.label: labels,
self.initial_outputs: initial_outputs
self.label: labels
}
feed_dict.update({i: d for i, d in zip(self.initial_states, initial_states)})

loss_value, summaries_value = session.run(
[loss, summaries],
feed_dict=feed_dict
)
feed_dict.update({
i: np.repeat(d, len(dataset), axis=0)
for i, d in zip(self.initial_states, initial_states)
})
feed_dict.update({
self.initial_outputs: np.repeat(
initial_outputs, len(dataset), axis=0
).reshape(
initial_outputs.shape[0], len(dataset), initial_outputs.shape[1]
)
}) if self.with_prev_output and initial_outputs is not None else None

loss_value, summaries_value = session.run([loss, summaries], feed_dict=feed_dict)

self.validation_writer.add_summary(summaries_value, step)

Expand Down
8 changes: 7 additions & 1 deletion src/model/sequence_model/classSequenceModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,13 @@ def compute_gradient(self, loss, global_step, max_value=1.):
"""
gvs = self.optimizer.compute_gradients(loss)
self.logger.debug(gvs) if self.logger else None
capped_gvs = [(tf.clip_by_value(grad, -max_value, max_value), var) for grad, var in gvs]
capped_gvs = [
(
tf.clip_by_value(grad, -max_value, max_value) if grad is not None else grad,
var
)
for grad, var in gvs
]
return self.optimizer.apply_gradients(capped_gvs, global_step=global_step)

def fit(self, train_set, validation_set):
Expand Down
54 changes: 47 additions & 7 deletions src/model/sequence_model/test/testRNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,20 @@ def test_fit(self):
n_features = 10
n_output = 3

train_set = np.random.rand(batch_size, time_steps, n_features)
validation_set = np.random.rand(batch_size, time_steps, n_features)
initial_states = [np.random.rand(units[l][0]) for l in range(len(units))]
n_train = 100
n_valid = 100

train_set = [
(np.random.rand(time_steps, n_features), np.ones((time_steps, n_output)) * i)
for i in range(n_train)
]

validation_set = [
(np.random.rand(time_steps, n_features), np.ones((time_steps, n_output)) * i)
for i in range(n_valid)
]

initial_states = [np.random.rand(1, units[l][0]) for l in range(len(units))]

rnn_1 = RNN(
units, f_out, batch_size=batch_size, time_steps=time_steps, n_features=n_features,
Expand All @@ -147,6 +158,16 @@ def test_fit(self):

rnn_1.fit(train_set, validation_set, initial_states)

train_set = [
(np.random.rand(time_steps, n_features), np.ones(n_output) * i)
for i in range(n_train)
]

validation_set = [
(np.random.rand(time_steps, n_features), np.ones(n_output) * i)
for i in range(n_valid)
]

rnn_2 = RNN(
units, f_out, batch_size=batch_size, time_steps=time_steps, n_features=n_features,
n_output=n_output, with_prev_output=False, with_input=True, return_sequences=False,
Expand All @@ -169,8 +190,18 @@ def test_fit(self):
[300, 150, 30, 150, 300]
]

initial_states = [np.random.rand(units[l][0]) for l in range(len(units))]
initial_outputs = np.zeros((len(units), batch_size, n_output))
train_set = [
(np.random.rand(time_steps, n_features), np.ones((time_steps, n_output)) * i)
for i in range(n_train)
]

validation_set = [
(np.random.rand(time_steps, n_features), np.ones((time_steps, n_output)) * i)
for i in range(n_valid)
]

initial_states = [np.random.rand(1, units[l][0]) for l in range(len(units))]
initial_outputs = np.zeros((len(units), n_output))

rnn_4 = RNN(
units, f_out, batch_size=batch_size, time_steps=time_steps, n_features=n_features,
Expand All @@ -179,13 +210,23 @@ def test_fit(self):

rnn_4.fit(train_set, validation_set, initial_states, initial_outputs=initial_outputs)

train_set = [
(np.random.rand(time_steps, n_features), np.ones(n_output) * i)
for i in range(n_train)
]

validation_set = [
(np.random.rand(time_steps, n_features), np.ones(n_output) * i)
for i in range(n_valid)
]

rnn_5 = RNN(
units, f_out, batch_size=batch_size, time_steps=time_steps, n_features=n_features,
n_output=n_output, with_prev_output=True, with_input=False, return_sequences=False,
summary_path = SUMMARY_PATH, checkpoint_path = CHECKPOINT_PATH, name="rnn_5"
)

rnn_5.fit(train_set, validation_set, initial_states)
rnn_5.fit(train_set, validation_set, initial_states, initial_outputs=initial_outputs)

rnn_6 = RNN(
units, f_out, batch_size=batch_size, time_steps=time_steps, n_features=n_features,
Expand All @@ -195,6 +236,5 @@ def test_fit(self):

rnn_6.fit(train_set, validation_set, initial_states)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0fb54bd

Please sign in to comment.