Skip to content

Commit

Permalink
finish session
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhangDesai committed Nov 13, 2016
1 parent b4de810 commit 6ec0c51
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
8 changes: 4 additions & 4 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def next(self):
ox = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
ob = tf.Variable(tf.zeros([num_nodes]))
# Classifier weights and biases
w = tf.Variable(tf.random_uniform([num_nodes, num_nodes]))
b = tf.Variable(tf.zeros([num_nodes, 1]))

def lstm_cell(i, o, state):
input_gate = tf.sigmoid(tf.matmul(ix, i) + tf.matmul(im, o) + ib)
Expand All @@ -81,17 +78,17 @@ def lstm_cell(i, o, state):
output = tf.zeros([num_nodes, 1])
state = tf.zeros([num_nodes, 1])
i = 0
MSE = 0
while i < (days-1):
output, state = lstm_cell(tf.reshape(x[:, i], [5, 1]), output, state)
y.append(output)
MSE += tf.reduce_mean(tf.square(output - tf.reshape(x[:, i+1], [5, 1])))
i += 1
x = x[:, 1:]

logits = tf.matmul(w, tf.reshape(tf.reduce_sum(tf.concat(0, y), 0), [5, 1])) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, tf.reshape(tf.reduce_sum(tf.concat(0, x), 1), [5, 1])))
loss = MSE/(days-1)
optimizer = tf.train.AdamOptimizer().minimize(loss)
yT, _ = lstm_cell(tf.reshape(x[:, i], [5, 1]), output, state)
prediction = tf.matmul(w, yT) + b

num_steps = 2000
summary_frequency = 100
Expand All @@ -101,11 +98,17 @@ def lstm_cell(i, o, state):
mean_loss = 0
for step in range(num_steps):
batch = train_batches.next()
batch = batch.astype(np.float32)
print(batch.shape)
print(batch.dtype)
_, l = sess.run([optimizer, loss], feed_dict={x: batch})
mean_loss += l
if (step+1) % summary_frequency == 0:
mean_loss = mean_loss/summary_frequency
print('Average MSE at step %f: %s' % (step, mean_loss))
valid_batch = valid_batches.next()
valid_batch = batch.astype(np.float32)
v_l = sess.run([loss], feed_dict={x: valid_batch})
print('Validation MSE: %f' % v_l)

save_path = tf.train.Saver().save(sess, '../models/' + stock + '_model.ckpt')

0 comments on commit 6ec0c51

Please sign in to comment.