enable dict input; bug fixes
farizrahman4u committed Nov 5, 2016
Parent: f96a0ea · Commit: 66e4518
Showing 1 changed file with 136 additions and 15 deletions.
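The headline feature: RecurrentContainer.__call__ now also accepts a dict mapping the keys 'input', 'ground_truth', 'initial_readout', and 'states' to tensors, alongside the existing single-tensor and list forms. A minimal usage sketch follows; the cell type, sizes, and import paths are illustrative assumptions, not part of this commit:

# Hedged sketch of the dict-input API introduced by this commit.
# GRUCell, the shapes, and the import paths are assumptions for illustration.
from keras.layers import Input
from recurrentshop.engine import RecurrentContainer
from recurrentshop.cells import GRUCell

rc = RecurrentContainer(readout=True, return_states=True, input_length=10)
rc.add(GRUCell(32, input_dim=20))

x = Input((10, 20))              # main input sequence
ground_truth = Input((10, 32))   # teacher-forcing signal for the readout
init_readout = Input((32,))      # readout fed back at the first time step
init_state = Input((32,))        # explicit initial state for the GRU cell

# Optional keys may be omitted; missing entries fall back to None,
# mirroring x = [x, None, None, None] in __call__ below.
y = rc({'input': x,
        'ground_truth': ground_truth,
        'initial_readout': init_readout,
        'states': [init_state]})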
151 changes: 136 additions & 15 deletions recurrentshop/engine.py
@@ -167,8 +167,9 @@ def get_config(self):
 
 class RecurrentContainer(Layer):
 
-    def __init__(self, weights=None, return_sequences=False, go_backwards=False, stateful=False, readout=False, state_sync=False, decode=False, output_length=None, input_length=None, unroll=False, **kwargs):
+    def __init__(self, weights=None, return_sequences=False, return_states=False, go_backwards=False, stateful=False, readout=False, state_sync=False, decode=False, output_length=None, input_length=None, unroll=False, **kwargs):
         self.return_sequences = return_sequences or decode
+        self.return_states = return_states
         self.initial_weights = weights
         self.go_backwards = go_backwards
         self.stateful = stateful
@@ -186,6 +187,7 @@ def __init__(self, weights=None, return_sequences=False, go_backwards=False, sta
         self.model = Sequential()
         self.supports_masking = True
         self._truth_tensor = None
+        self.initial_readout = None
         super(RecurrentContainer, self).__init__(**kwargs)
 
     def add(self, layer):
@@ -227,16 +229,23 @@ def input_shape(self):
     def output_shape(self):
         shape = self.model.output_shape
         if self.decode:
-            return (shape[0], self.output_length) + shape[1:]
-        if self.return_sequences:
+            shape = (shape[0], self.output_length) + shape[1:]
+        elif self.return_sequences:
             input_length = self.input_spec[0].shape[1]
-            return (shape[0], input_length) + shape[1:]
-        else:
-            return shape
+            shape = (shape[0], input_length) + shape[1:]
+        if self.return_states:
+            shape = [shape] + [None] * self.nb_states
+        return shape
 
     def get_output_shape_for(self, input_shape):
-        # this is a container
-        return (input_shape[0],) + self.output_shape[1:]
+        if self.return_states:
+            output_shape = self.output_shape
+            state_shapes = output_shape[1:]
+            output_shape = output_shape[0]
+            output_shape = (input_shape[0],) + output_shape[1:]
+            return [output_shape] + state_shapes
+        else:
+            return (input_shape[0],) + self.output_shape[1:]
 
     def step(self, x, states):
         states = list(states)
@@ -279,6 +288,18 @@ def step(self, x, states):
         return x, states
 
     def call(self, x, mask=None):
+        if type(x) in [list, tuple]:
+            if 'ground_truth' in self.input_format:
+                self.set_truth_tensor(x[self.input_format.index('ground_truth')])
+            if 'initial_readout' in self.input_format:
+                self.initial_readout = x[self.input_format.index('initial_readout')]
+            if 'states' in self.input_format:
+                states = x[self.input_format.index('states'):]
+                for i in range(len(states)):
+                    self.set_state(self.state_indices[i], states[i])
+            x = x[0]
+        if not self.initial_readout and self.readout == 'readout_only':
+            self.initial_readout = x
         unroll = self.unroll
         '''
         if K.backend() == 'tensorflow':
@@ -318,18 +339,22 @@ def call(self, x, mask=None):
         states = list(states)
         if self.stateful:
             for i in range(len(states)):
-                self.updates.append((self.states[i], states[i]))
+                if type(self.states[i]) == type(K.zeros((1,))):
+                    self.updates.append((self.states[i], states[i]))
         if self.decode:
             states.pop(0)
         if self.readout:
             states.pop(-1)
             if self._truth_tensor:
                 states.pop(-1)
         self.state_outputs = states
         if self.return_sequences:
-            return outputs
+            y = outputs
         else:
-            return last_output
+            y = last_output
+        if self.return_states:
+            y = [y] + states
+        return y
 
     def get_initial_states(self, x):
         initial_states = []
@@ -359,7 +384,7 @@ def get_initial_states(self, x):
         if self.readout:
             if self._truth_tensor:
                 initial_states += [K.zeros((1,), dtype='int32')]
-            if hasattr(self, 'initial_readout'):
+            if self.initial_readout:
                 initial_readout = self._get_state_from_info(self.initial_readout, input, batch_size, input_length)
                 initial_states += [initial_readout]
             else:
@@ -373,7 +398,7 @@ def reset_states(self):
         for layer in self.model.layers:
             if _isRNN(layer):
                 for state in layer.states:
-                    assert type(state) in [tuple, list] or 'numpy' in str(type(state)), 'Stateful RNNs require states with static shapes'
+                    #assert type(state) in [tuple, list] or 'numpy' in str(type(state)), 'Stateful RNNs require states with static shapes'
                     if 'numpy' in str(type(state)):
                         states += [K.variable(state)]
                     elif type(state) in [list, tuple]:
@@ -419,6 +444,14 @@ def _get_state_from_info(self, info, input, batch_size, input_length):
         else:
             return info
 
+    def compute_mask(self, input, input_mask=None):
+        if type(input_mask) is list:
+            input_mask = input_mask[0]
+        if self.return_states:
+            return [input_mask] + [None] * self.nb_states
+        else:
+            return input_mask
+
     @property
     def trainable_weights(self):
         if not self.model.layers:
@@ -459,7 +492,7 @@ def set_truth_tensor(self, val):
         self._truth_tensor = val
 
     def get_config(self):
-        attribs = ['return_sequences', 'go_backwards', 'stateful', 'readout', 'state_sync', 'decode', 'input_length', 'unroll', 'output_length']
+        attribs = ['return_sequences', 'return_states', 'go_backwards', 'stateful', 'readout', 'state_sync', 'decode', 'input_length', 'unroll', 'output_length']
         config = {x : getattr(self, x) for x in attribs}
         config['model'] = self.model.get_config()
         base_config = super(RecurrentContainer, self).get_config()
@@ -476,3 +509,91 @@ def from_config(cls, config):
             layer = layer_from_config(layer_config, cells.__dict__)
             rc.add(layer)
         return rc
+
+    def __call__(self, x, mask=None):
+        args = ['input', 'ground_truth', 'initial_readout', 'states']
+        if type(x) is dict:
+            x = map(x.get, args)
+        elif type(x) not in [list, tuple]:
+            x = [x, None, None, None]
+        self.input_format = []
+        input_tensors = []
+        for i in range(3):
+            if x[i]:
+                self.input_format += [args[i]]
+                input_tensors += [x[i]]
+        if x[3]:
+            self.input_format += [args[3]]
+            states = []
+            self.state_indices = []
+            for i in range(len(x[3])):
+                if x[3][i]:
+                    states += [x[3][i]]
+                    self.state_indices += [i]
+            input_tensors += states
+
+        if not self.built:
+            self.assert_input_compatibility(x)
+            input_shapes = []
+            for x_elem in input_tensors:
+                if hasattr(x_elem, '_keras_shape'):
+                    input_shapes.append(x_elem._keras_shape)
+                elif hasattr(K, 'int_shape'):
+                    input_shapes.append(K.int_shape(x_elem))
+                elif x_elem:
+                    raise Exception('You tried to call layer "' + self.name +
+                                    '". This layer has no information'
+                                    ' about its expected input shape, '
+                                    'and thus cannot be built. '
+                                    'You can build it manually via: '
+                                    '`layer.build(batch_input_shape)`')
+            self.build(input_shapes[0])
+            self.built = True
+        self.assert_input_compatibility(x[0])
+        input_added = False
+        inbound_layers = []
+        node_indices = []
+        tensor_indices = []
+        self.ignore_indices = []
+        for i in range(len(input_tensors)):
+            input_tensor = input_tensors[i]
+            if hasattr(input_tensor, '_keras_history') and input_tensor._keras_history:
+                previous_layer, node_index, tensor_index = input_tensor._keras_history
+                inbound_layers.append(previous_layer)
+                node_indices.append(node_index)
+                tensor_indices.append(tensor_index)
+            else:
+                inbound_layers = None
+                break
+        if inbound_layers:
+            self.add_inbound_node(inbound_layers, node_indices, tensor_indices)
+            input_added = True
+        if input_added:
+            outputs = self.inbound_nodes[-1].output_tensors
+            if len(outputs) == 1:
+                return outputs[0]
+            else:
+                return outputs
+        else:
+            return self.call(x, mask)
+
+    def set_state(self, index, state):
+        n = 0
+        for layer in self.model.layers:
+            if _isRNN(layer):
+                if self.state_sync:
+                    layer.states[index] = state
+                    return
+                n += len(layer.states)
+                if index < n:
+                    layer.states[index + len(layer.states) - n] = state
+                    return
+
+    @property
+    def nb_states(self):
+        if self.state_sync:
+            for layer in self.model.layers:
+                if _isRNN(layer):
+                    return len(layer.states)
+        return 0
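A note on the new return_states flag: when it is set, call, compute_mask, and get_output_shape_for all return a list, the output tensor followed by one entry per state (nb_states of them). A minimal sketch under the same illustrative assumptions as the example above:

# Hedged sketch of return_states (shapes and the LSTMCell choice are assumed).
from keras.layers import Input
from recurrentshop.engine import RecurrentContainer
from recurrentshop.cells import LSTMCell

rc = RecurrentContainer(return_states=True, state_sync=True, input_length=8)
rc.add(LSTMCell(16, input_dim=4))

x = Input((8, 4))
out = rc(x)      # list: [output, state_0, state_1]
y, h, c = out    # an LSTM cell carries two states (h and c)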
