Skip to content

Commit

Permalink
revise dynamic routing init value bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangxinyang227 committed Nov 21, 2019
1 parent 943af41 commit d515934
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions induction_network/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def __init__(self, config, vocab_size, word_vectors):
self.vocab_size = vocab_size
self.word_vectors = word_vectors

self.num_classes = self.config["num_classes"]

# [num_classes, num_support, sequence_length]
self.support = tf.placeholder(tf.int32, [None, None, None], name="support")
# [num_classes * num_queries, sequence_length]
Expand Down Expand Up @@ -86,16 +88,17 @@ def model_structure(self):
with tf.name_scope("induction_module"):

support_class = self.dynamic_routing(tf.reshape(support_final_output,
[self.config["num_classes"],
[self.num_classes,
self.config["num_support"],
-1]))

with tf.name_scope("relation_module"):
scores = self.neural_tensor_layer(support_class, queries_final_output)
self.scores = scores
self.predictions = tf.argmax(scores, axis=-1, name="predictions")

with tf.name_scope("loss"):
labels_one_hot = tf.one_hot(self.labels, self.config["num_classes"], dtype=tf.float32)
labels_one_hot = tf.one_hot(self.labels, self.num_classes, dtype=tf.float32)
losses = tf.losses.mean_squared_error(labels=labels_one_hot, predictions=scores)
l2_losses = tf.add_n(
[tf.nn.l2_loss(v)
Expand All @@ -120,12 +123,12 @@ def dynamic_routing(self, support_encoding, iter_routing=3):
:return:
"""

num_classes = self.config["num_classes"]
num_classes = self.num_classes
num_support = self.config["num_support"]
encode_size = self.config["hidden_sizes"][-1] * 2

# init dynamic routing values, weights of samples per class. [num_classes, num_support]
init_b = tf.Variable(tf.constant(0.0, dtype=tf.float32, shape=[num_classes, num_support]))
init_b = tf.constant(0.0, dtype=tf.float32, shape=[num_classes, num_support])

# transformer matrix, mapping input to another space. [encode_size, encode_size]
w_s = tf.get_variable("w_s", shape=[encode_size, encode_size], dtype=tf.float32,
Expand Down Expand Up @@ -162,7 +165,7 @@ def neural_tensor_layer(self, class_vector, query_encoder):
:param query_encoder: query set encoding matrix. [num_classes * num_queries, encode_size]
:return:
"""
num_classes = self.config["num_classes"]
num_classes = self.num_classes
encode_size = self.config["hidden_sizes"][-1] * 2
layer_size = self.config["layer_size"]

Expand Down Expand Up @@ -294,6 +297,6 @@ def infer(self, sess, batch):
self.queries: batch["queries"],
self.keep_prob: 1.0}

predict = sess.run([self.predictions], feed_dict=feed_dict)
predict, scores = sess.run([self.predictions, self.scores], feed_dict=feed_dict)

return predict
return predict, scores

0 comments on commit d515934

Please sign in to comment.