debugging first attempt, need to fix shapes..
andhus committed Aug 1, 2017
1 parent 154937e commit 6d262e2
Showing 2 changed files with 24 additions and 11 deletions.
examples/attention/recurrent_attention.py (16 changes: 10 additions & 6 deletions)
@@ -6,25 +6,29 @@
 from keras.engine import Model
 from keras.layers import Dense, SimpleRNN, TimeDistributed
 
-from extkeras.layers.attention import RecurrentAttentionWrapper
+from extkeras.layers.attention import (
+    RecurrentAttentionWrapper,
+    DenseStatelessAttention
+)
 
 n_timesteps = 7
 n_features = 5
+n_features_attention = 2
 n_samples = 1000
 
 features = Input((n_timesteps, n_features))
-attended = Input((n_features, )) # TODO same as number of features for now due to test hack...
+attended = Input((n_features_attention, ))
 
-recurrent_layer = SimpleRNN(units=4)
-attention_model = Dense(units=4)
+recurrent_layer = SimpleRNN(units=4, implementation=1)
+attention_model = DenseStatelessAttention(units=3)
 
 rnn = RecurrentAttentionWrapper(
     attention_layer=attention_model,
     recurrent_layer=recurrent_layer
 )
 output_layer = Dense(1, activation='sigmoid')
 
-last_state = rnn(features, attended=attended)
+last_state = rnn([features, attended])
 output = output_layer(last_state)
 
 model = Model(
@@ -33,7 +37,7 @@
 )
 
 features_data = np.random.randn(n_samples, n_timesteps, n_features)
-attended_data = np.ones((n_samples, n_features), dtype=float)
+attended_data = np.ones((n_samples, n_features_attention), dtype=float)
 attended_data[::2] = 0.
 target_data = attended_data.mean(axis=1, keepdims=True)

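The example stops at constructing the model. A minimal sketch of how it might be compiled and trained, assuming the Model is built with inputs=[features, attended] and outputs=output as the surrounding lines suggest; the optimizer, loss, and fit arguments are illustrative and not part of this commit:

model.compile(optimizer='adam', loss='mse')
model.fit(
    [features_data, attended_data],
    target_data,
    batch_size=32,
    epochs=1,
)
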
extkeras/layers/attention.py (19 changes: 14 additions & 5 deletions)
@@ -2,6 +2,7 @@
 
 
 from keras import backend as K
+from keras.engine import InputSpec
 from keras.layers import Dense, concatenate
 from keras.layers.recurrent import Recurrent
 from keras.layers.recurrent import SimpleRNN as _SimpleRNN
@@ -22,6 +23,8 @@ def __init__(self, attention_layer, recurrent_layer):
             recurrent_layer
         )
         self._attended = None
+        self.input_spec = [InputSpec(ndim=3), None]
+        # later should be set to attention_layer.input_spec
 
     @property
     def states(self):
@@ -43,10 +46,14 @@ def compute_output_shape(self, input_shape):
     def build(self, input_shape):
         [input_shape, attended_shape] = input_shape
         wrapped_recurrent_step_input_shape = input_shape[:1] + input_shape[-1:]
+
         wrapped_recurrent_state_shapes = [
-            (input_shape[0], self.recurrent_layer.units)
-            for _ in self.recurrent_layer.states
-        ]
+            input_shape[:1] + spec.shape[1:]
+            for spec in self.recurrent_layer.state_spec
+        ] if isinstance(self.recurrent_layer.state_spec, list) else [(
+            input_shape[:1] + self.recurrent_layer.state_spec.shape[1:]
+        )]
+
         self.attention_layer.build(
             attended_shape,
             wrapped_recurrent_step_input_shape,
@@ -59,6 +66,8 @@ def build(self, input_shape):
             )[-1:]
         )
         self.recurrent_layer.build(wrapped_recurrent_input_shape)
+
+        self.input_spec = [InputSpec(ndim=3), InputSpec(shape=attended_shape)]
         self.built = True
 
     def call(
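
The reworked build() derives the wrapped layer's state shapes from its state_spec instead of assuming a single (batch, units) state. A standalone illustration of that shape arithmetic, using a hypothetical stand-in for keras.engine.InputSpec so it runs without Keras; the shapes mirror the example script (SimpleRNN with 4 units):

from collections import namedtuple

# Hypothetical stand-in for keras.engine.InputSpec; only .shape is needed here.
InputSpecStub = namedtuple('InputSpecStub', ['shape'])

input_shape = (None, 7, 5)                     # (batch, timesteps, features)
state_spec = [InputSpecStub(shape=(None, 4))]  # e.g. SimpleRNN(units=4)

# Same branching as in build(): state_spec may be a single spec or a list.
wrapped_recurrent_state_shapes = (
    [input_shape[:1] + spec.shape[1:] for spec in state_spec]
    if isinstance(state_spec, list)
    else [input_shape[:1] + state_spec.shape[1:]]
)
print(wrapped_recurrent_state_shapes)  # [(None, 4)]
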
@@ -91,7 +100,7 @@ def step(self, inputs, states):
         wrapped_recurrent_input = self.attention_layer.attention_step(
             attended=attended,
             recurrent_input=inputs,
-            recurrent_states=states,
+            recurrent_states=list(states[:-2]),
             attention_states=[] # TODO fix!
         )
         return self.recurrent_layer.step(wrapped_recurrent_input, states)
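
The recurrent_states=list(states[:-2]) change appears to assume that the states tuple seen by step() ends with the two constants (dropout masks) that Keras 2.0.x recurrent layers append via get_constants(), so only the true states reach the attention step while the wrapped layer's own step still receives the full tuple. A toy illustration of that slicing (the names are placeholders, not Keras tensors):

# What SimpleRNN.step sees in Keras 2.0.x: previous output followed by the
# dropout and recurrent-dropout masks; the last two are constants, not states.
states = ('h_prev', 'dp_mask', 'rec_dp_mask')

recurrent_states = list(states[:-2])  # -> ['h_prev'], passed to the attention step
print(recurrent_states)
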
@@ -168,7 +177,7 @@ def build(
                 (
                     attended_shape[1] +
                     recurrent_step_input_shape[1] +
-                    sum(s[1] for s in recurrent_state_shapes)
+                    sum([s[1] for s in recurrent_state_shapes])
                 )
             )
         )
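
The summed dimension in the attention layer's build() suggests the dense attention operates on the concatenation of the attended vector, the recurrent step input, and all recurrent states. With the shapes from examples/attention/recurrent_attention.py (2 attended features, 5 input features, one 4-unit SimpleRNN state), that input dimension works out as follows; the shape values are taken from the example, the snippet itself is only illustrative:

attended_shape = (None, 2)              # n_features_attention = 2
recurrent_step_input_shape = (None, 5)  # n_features = 5
recurrent_state_shapes = [(None, 4)]    # SimpleRNN(units=4) -> one state

dense_input_dim = (
    attended_shape[1] +
    recurrent_step_input_shape[1] +
    sum([s[1] for s in recurrent_state_shapes])
)
print(dense_input_dim)  # 2 + 5 + 4 = 11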
