diff --git a/main.py b/main.py
index 00bc913..952ba02 100644
--- a/main.py
+++ b/main.py
@@ -55,7 +55,15 @@ def get_output_shape_for(self, input_shape):
         return input_shape[:self.axis] + input_shape[self.axis+1:]
 
     def get_output_for(self, input, **kwargs):
-        return T.sum(input, axis=self.axis, dtype=theano.config.floatX)
+        return T.sum(input, axis=self.axis)
+
+class TemporalSumLayer(SumLayer):
+    def __init__(self, incoming, axis, T=lasagne.init.Normal(std=0.1), **kwargs):
+        super(TemporalSumLayer, self).__init__(incoming, axis, **kwargs)
+        self.T = self.add_param(T, (self.input_shape[axis-1], self.input_shape[axis+1]), name="T")
+
+    def get_output_for(self, input, **kwargs):
+        return T.sum(input, axis=self.axis) + self.T
 
 class TransposedDenseLayer(lasagne.layers.DenseLayer):
     def __init__(self, incoming, num_units, W=lasagne.init.GlorotUniform(),
@@ -76,7 +84,7 @@ def get_output_for(self, input, **kwargs):
         return self.nonlinearity(activation)
 
 class MemoryNetworkLayer(lasagne.layers.MergeLayer):
-    def __init__(self, incomings, vocab, embedding_size, A=lasagne.init.Normal(std=0.1), C=lasagne.init.Normal(std=0.1), **kwargs):
+    def __init__(self, incomings, vocab, embedding_size, A, A_T, C, C_T, **kwargs):
         super(MemoryNetworkLayer, self).__init__(incomings, **kwargs)
         if len(incomings) != 2:
             raise NotImplementedError
@@ -90,12 +98,14 @@ def __init__(self, incomings, vocab, embedding_size, A=lasagne.init.Normal(std=0
         l_A_embedding = lasagne.layers.EmbeddingLayer(l_context_in, len(vocab)+1, embedding_size, W=A)
         self.A = l_A_embedding.W
         l_A_embedding = lasagne.layers.ReshapeLayer(l_A_embedding, shape=(batch_size, max_seqlen, max_sentlen, embedding_size))
-        l_A_embedding = SumLayer(l_A_embedding, axis=2)
+        l_A_embedding = TemporalSumLayer(l_A_embedding, axis=2, T=A_T)
+        self.A_T = l_A_embedding.T
 
         l_C_embedding = lasagne.layers.EmbeddingLayer(l_context_in, len(vocab)+1, embedding_size, W=C)
         self.C = l_C_embedding.W
         l_C_embedding = lasagne.layers.ReshapeLayer(l_C_embedding, shape=(batch_size, max_seqlen, max_sentlen, embedding_size))
-        l_C_embedding = SumLayer(l_C_embedding, axis=2)
+        l_C_embedding = TemporalSumLayer(l_C_embedding, axis=2, T=C_T)
+        self.C_T = l_C_embedding.T
 
         l_prob = InnerProductLayer((l_A_embedding, l_B_embedding), nonlinearity=lasagne.nonlinearities.softmax)
         l_weighted_output = BatchedDotLayer((l_prob, l_C_embedding))
@@ -176,6 +186,7 @@ def __init__(self, train_file, test_file, batch_size=32, embedding_size=20, max_
         l_question_in = lasagne.layers.InputLayer(shape=(batch_size, max_sentlen))
 
         A, C = lasagne.init.Normal(std=0.1).sample((len(vocab)+1, embedding_size)), lasagne.init.Normal(std=0.1)
+        A_T, C_T = lasagne.init.Normal(std=0.1), lasagne.init.Normal(std=0.1)
         W = A if adj_weight_tying else lasagne.init.Normal(std=0.1)
 
         l_question_in = lasagne.layers.ReshapeLayer(l_question_in, shape=(batch_size * max_sentlen, ))
@@ -184,13 +195,15 @@ def __init__(self, train_file, test_file, batch_size=32, embedding_size=20, max_
         l_B_embedding = lasagne.layers.ReshapeLayer(l_B_embedding, shape=(batch_size, max_sentlen, embedding_size))
         l_B_embedding = SumLayer(l_B_embedding, axis=1)
 
-        self.mem_layers = [MemoryNetworkLayer((l_context_in, l_B_embedding), vocab, embedding_size, A=A, C=C)]
+        self.mem_layers = [MemoryNetworkLayer((l_context_in, l_B_embedding), vocab, embedding_size, A=A, A_T=A_T, C=C, C_T=C_T)]
         for _ in range(1, num_hops):
             if adj_weight_tying:
                 A, C = self.mem_layers[-1].C, lasagne.init.Normal(std=0.1)
+                A_T, C_T = self.mem_layers[-1].C_T, lasagne.init.Normal(std=0.1)
             else:  # RNN style
                 A, C = self.mem_layers[-1].A, self.mem_layers[-1].C
-            self.mem_layers += [MemoryNetworkLayer((l_context_in, self.mem_layers[-1]), vocab, embedding_size, A=A, C=C)]
+                A_T, C_T = self.mem_layers[-1].A_T, self.mem_layers[-1].C_T
+            self.mem_layers += [MemoryNetworkLayer((l_context_in, self.mem_layers[-1]), vocab, embedding_size, A=A, A_T=A_T, C=C, C_T=C_T)]
 
         if adj_weight_tying:
             l_pred = TransposedDenseLayer(self.mem_layers[-1], self.num_classes, W=self.mem_layers[-1].C, b=None, nonlinearity=lasagne.nonlinearities.softmax)
@@ -350,7 +363,7 @@ def process_dataset(self, lines, word_to_idx, max_sentlen, offset):
                 S.append(word_indices)
             if line['type'] == 'q':
                 id = line['id']-1
-                indices = [offset+idx for idx in range(i-id, i) if lines[idx]['type'] == 's']
+                indices = [offset+idx for idx in range(i-id, i) if lines[idx]['type'] == 's'][::-1]
                 line['refs'] = [indices.index(offset+i-id+ref) for ref in line['refs']]
                 C.append(indices)
                 Q.append(offset+i)
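
For context: the new TemporalSumLayer appears to add the temporal encoding from End-to-End Memory Networks (Sukhbaatar et al., 2015). Each sentence is still reduced to a bag-of-words vector by summing its word embeddings, but a learned matrix T with one row per memory slot is added on top, i.e. m_i = sum_j A·x_ij + T_A[i], and likewise c_i = sum_j C·x_ij + T_C[i]. The [::-1] added in process_dataset reverses the context indices so that slot 0 holds the most recent sentence (line['refs'] is remapped through the reversed list accordingly), matching the reverse-order indexing the temporal matrices are trained against in the paper. Below is a minimal NumPy sketch of the forward computation only; NumPy stands in for Theano, and the names (emb, T, memories) are illustrative, not taken from main.py:

    import numpy as np

    batch_size, max_seqlen, max_sentlen, embedding_size = 2, 5, 7, 4

    # Embedded context after lookup and reshape, as in the diff above:
    # one embedding vector per word, per sentence, per batch item.
    emb = np.random.randn(batch_size, max_seqlen, max_sentlen, embedding_size)

    # Learned temporal term, one row per memory slot; the 0.1 scale
    # mirrors the Normal(std=0.1) initializer used for A_T and C_T.
    T = 0.1 * np.random.randn(max_seqlen, embedding_size)

    # Sum out the word axis (axis=2), then add T, broadcasting over batch:
    # memories[b, i] = sum_j emb[b, i, j] + T[i]
    memories = emb.sum(axis=2) + T
    assert memories.shape == (batch_size, max_seqlen, embedding_size)

Note that with adjacent weight tying the temporal matrices are tied the same way as the embeddings: each new hop reuses the previous hop's C_T as its A_T, while the untied "RNN style" branch carries both A_T and C_T forward unchanged.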