Skip to content

Commit

Permalink
cleaned up examples
Browse files Browse the repository at this point in the history
  • Loading branch information
andhus committed Jul 27, 2017
1 parent c4f9dfa commit 8932d18
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 21 deletions.
15 changes: 5 additions & 10 deletions examples/cell_masked_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

from extkeras.layers.recurrent import CellMaskedLSTM

n_timesteps = 5
n_features = 3
n_units = 2

features_and_cell_mask = Input((5, 5))
features_and_cell_mask = Input((n_timesteps, n_features + n_units))
cell_masked_lstm = CellMaskedLSTM(
units=n_units,
output_cells=True, # to be able to inspect that the cells are masked
Expand All @@ -17,15 +19,7 @@
state_cell_sequence = cell_masked_lstm(features_and_cell_mask)
model = Model(inputs=features_and_cell_mask, outputs=state_cell_sequence)

features_data = np.array([
[
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
]
])
features_data = np.random.randn(1, n_timesteps, n_features)
cell_mask_data = np.array([
[
[0., 0.],
Expand All @@ -35,6 +29,7 @@
[0., 1.],
]
])
assert cell_mask_data.shape == (1, n_timesteps, n_units)
feature_and_cell_mask_data = np.concatenate([features_data, cell_mask_data], 2)

state_cell_sequence_data = model.predict(feature_and_cell_mask_data)
Expand Down
19 changes: 8 additions & 11 deletions examples/phased_lstm_cell_mask.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
from __future__ import division, print_function


import numpy as np
from keras.engine import Input
from keras.engine import Model
from keras.layers import TimeDistributed
from extkeras.layers.recurrent import PhasedLSTMCellMask

n_samples = 3
sequence_length = 10
units = 4
n_timesteps = 10
n_units = 4

time = Input((sequence_length, 1))
time = Input((n_timesteps, 1))

cell_mask_layer = TimeDistributed(PhasedLSTMCellMask(units))
cell_mask_layer = TimeDistributed(PhasedLSTMCellMask(n_units))
cell_mask = cell_mask_layer(time)
model = Model(inputs=time, outputs=cell_mask)

model = Model(
inputs=time,
outputs=cell_mask
)

time_arr = np.arange(sequence_length).reshape((1, -1, 1)).repeat(n_samples, axis=0)
time_arr = np.arange(n_timesteps).reshape((1, -1, 1)).repeat(n_samples, axis=0)
mask = model.predict(time_arr)

# TODO add plot of mask!

0 comments on commit 8932d18

Please sign in to comment.