Skip to content

Commit

Permalink
tf.concat arguments reordered in time_distributed and some examples (t…
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruslanmlnkv authored and aymericdamien committed Apr 13, 2017
1 parent 4f4424c commit 3c75399
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/basics/weights_loading_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self):
with tf.variable_scope("scope2") as scope:
net_dnn = Model2.make_core_network(inputs) # shape (?, 10)

network = tf.concat(1, [net_conv, net_dnn], name="concat") # shape (?, 20)
network = tf.concat([net_conv, net_dnn], 1, name="concat") # shape (?, 20)
network = tflearn.fully_connected(network, 10, activation="softmax")
network = regression(network, optimizer='adam', learning_rate=0.01,
loss='categorical_crossentropy', name='target')
Expand Down
2 changes: 1 addition & 1 deletion examples/others/recommender_wide_and_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def deep_model(self, wide_inputs, n_inputs, n_nodes=[100, 50], use_dropout=False
print (" %s_embed = %s" % (cc, cc_embed_var[cc]))
flat_vars.append(tf.squeeze(cc_embed_var[cc], squeeze_dims=[1], name="%s_squeeze" % cc))

network = tf.concat(1, [wide_inputs] + flat_vars, name="deep_concat")
network = tf.concat([wide_inputs] + flat_vars, 1, name="deep_concat")
for k in range(len(n_nodes)):
network = tflearn.fully_connected(network, n_nodes[k], activation="relu", name="deep_fc%d" % (k+1))
if use_dropout:
Expand Down
2 changes: 1 addition & 1 deletion tflearn/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,4 +644,4 @@ def time_distributed(incoming, fn, args=None, scope=None):
x = map(lambda t: tf.reshape(t, [-1, 1]+utils.get_incoming_shape(t)[1:]), x)
except:
x = list(map(lambda t: tf.reshape(t, [-1, 1]+utils.get_incoming_shape(t)[1:]), x))
return tf.concat(1, x)
return tf.concat(x, 1)

0 comments on commit 3c75399

Please sign in to comment.