HW5c fix (nit): make build_rnn more clear, remove unneeded arg
katerakelly committed Nov 7, 2018
1 parent 5e908c2 commit 0d5ed41
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions hw5/meta/train_policy.py
@@ -57,11 +57,14 @@ def build_mlp(x, output_size, scope, n_layers, size, activation=tf.tanh, output_
        x = tf.layers.dense(inputs=x, units=output_size, activation=output_activation, name='fc{}'.format(i + 1), kernel_regularizer=regularizer, bias_regularizer=regularizer)
    return x

-def build_rnn(x, h, output_size, scope, n_layers, size, gru_size, activation=tf.tanh, output_activation=None, regularizer=None):
+def build_rnn(x, h, output_size, scope, n_layers, size, activation=tf.tanh, output_activation=None, regularizer=None):
"""
builds a gated recurrent neural network
inputs are first embedded by an MLP then passed to a GRU cell
make MLP layers with `size` number of units
make the GRU with `output_size` number of units
arguments:
(see `build_policy()`)
@@ -96,7 +99,7 @@ def build_policy(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
"""
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
if recurrent:
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, gru_size, activation=activation, output_activation=output_activation)
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, activation=activation, output_activation=output_activation)
else:
x = tf.reshape(x, (-1, x.get_shape()[1]*x.get_shape()[2]))
x = build_mlp(x, gru_size, scope, n_layers + 1, size, activation=activation, output_activation=activation)
@@ -115,7 +118,7 @@ def build_critic(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
"""
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
if recurrent:
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, gru_size, activation=activation, output_activation=output_activation, regularizer=regularizer)
x, h = build_rnn(x, h, gru_size, scope, n_layers, size, activation=activation, output_activation=output_activation, regularizer=regularizer)
else:
x = tf.reshape(x, (-1, x.get_shape()[1]*x.get_shape()[2]))
x = build_mlp(x, gru_size, scope, n_layers + 1, size, activation=activation, output_activation=activation, regularizer=regularizer)
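Note on the change: in the old signature, the call sites passed `gru_size` twice (once as `output_size` and once as the trailing `gru_size` argument), so the extra parameter carried no information; the commit drops it and documents that the GRU width is simply `output_size`. The body of `build_rnn` is collapsed in this diff, so the following is only a minimal sketch of what the docstring describes (an MLP embedding of `size` units feeding a GRU of `output_size` units), not the repository's actual implementation; the handling of `output_activation` and the exact shapes are assumptions.

```python
import tensorflow as tf

def build_rnn(x, h, output_size, scope, n_layers, size,
              activation=tf.tanh, output_activation=None, regularizer=None):
    # Sketch only (assumed, not shown in the diff):
    #   x: (batch, time, features) observations, h: initial GRU hidden state.
    # Embed each timestep with an MLP of `size` units; `build_mlp` is the
    # helper defined earlier in train_policy.py.
    x = build_mlp(x, size, scope, n_layers, size,
                  activation=activation, output_activation=activation,
                  regularizer=regularizer)
    # Run a GRU whose hidden state has `output_size` units, starting from h.
    gru_cell = tf.nn.rnn_cell.GRUCell(output_size)
    x, h = tf.nn.dynamic_rnn(gru_cell, x, initial_state=h, dtype=tf.float32)
    # How `output_activation` and `regularizer` apply to the GRU output is not
    # visible in this diff, so they are only used in the MLP embedding here.
    return x, h
```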
