Skip to content

Commit

Permalink
Add some clarification comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
wxs committed Jul 2, 2015
1 parent 42497d9 commit c9fd2c8
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,25 @@ def get_input(self, train):
return self.input

def supports_masked_input(self):
''' Whether or not this layer respects the output mask of its previous layer in its calculations. If you try
to attach a layer that does *not* support masked_input to a layer that gives a non-None output_mask() that is
an error'''
return False

def get_output_mask(self, train=None):
'''
For some models (such as RNNs) you want a way of being able to mark some output data-points as
"masked", so they are not used in future calculations. In such a model, get_output_mask() should return a mask
of one less dimension than get_output() (so if get_output is (nb_samples, nb_timesteps, nb_dimensions), then the mask
is (nb_samples, nb_timesteps), with a one for every unmasked datapoint, and a zero for every masked one.
If there is *no* masking then it shall return None. For instance if you attach an Activation layer (they support masking)
to a layer with an output_mask, then that Activation shall also have an output_mask. If you attach it to a layer with no
such mask, then the Activation's get_output_mask shall return None.
Some layers have an output_mask even if their input is unmasked, notably Embedding which can turn the entry "0" into
a mask.
'''
return None

def set_weights(self, weights):
Expand Down Expand Up @@ -76,6 +92,10 @@ def get_params(self):
return self.params, regularizers, consts

class MaskedLayer(Layer):
'''
If your layer trivially supports masking (by simply copying the input mask to the output), then subclass MaskedLayer
instead of Layer, and make sure that you incorporate the input mask into your calculation of get_output()
'''
def supports_masked_input(self):
return True

Expand Down

0 comments on commit c9fd2c8

Please sign in to comment.