Skip to content

Commit

Permalink
Refactor add_weight to align it with get_variable
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Apr 18, 2017
1 parent fc4874f commit 8830c53
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 35 deletions.
14 changes: 10 additions & 4 deletions keras/engine/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,14 @@ def non_trainable_weights(self):
def non_trainable_weights(self, weights):
self._non_trainable_weights = weights

def add_weight(self, shape, initializer,
name=None,
trainable=True,
@interfaces.legacy_add_weight_support
def add_weight(self,
name,
shape,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
constraint=None):
"""Adds a weight variable to the layer.
Expand All @@ -381,7 +385,9 @@ def add_weight(self, shape, initializer,
The created weight variable.
"""
initializer = initializers.get(initializer)
weight = K.variable(initializer(shape), dtype=K.floatx(), name=name)
if dtype is None:
dtype = K.floatx()
weight = K.variable(initializer(shape), dtype=dtype, name=name)
if regularizer is not None:
self.add_loss(regularizer(weight))
if constraint is not None:
Expand Down
2 changes: 1 addition & 1 deletion keras/layers/advanced_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def build(self, input_shape):
for i in self.shared_axes:
param_shape[i - 1] = 1
self.param_broadcast[i - 1] = True
self.alpha = self.add_weight(param_shape,
self.alpha = self.add_weight(shape=param_shape,
name='alpha',
initializer=self.alpha_initializer,
regularizer=self.alpha_regularizer,
Expand Down
14 changes: 7 additions & 7 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ def build(self, input_shape):
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (input_dim, self.filters)

self.kernel = self.add_weight(kernel_shape,
self.kernel = self.add_weight(shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight((self.filters,),
self.bias = self.add_weight(shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down Expand Up @@ -721,13 +721,13 @@ def build(self, input_shape):
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (self.filters, input_dim)

self.kernel = self.add_weight(kernel_shape,
self.kernel = self.add_weight(shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight((self.filters,),
self.bias = self.add_weight(shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down Expand Up @@ -952,20 +952,20 @@ def build(self, input_shape):
self.filters)

self.depthwise_kernel = self.add_weight(
depthwise_kernel_shape,
shape=depthwise_kernel_shape,
initializer=self.depthwise_initializer,
name='depthwise_kernel',
regularizer=self.depthwise_regularizer,
constraint=self.depthwise_constraint)
self.pointwise_kernel = self.add_weight(
pointwise_kernel_shape,
shape=pointwise_kernel_shape,
initializer=self.pointwise_initializer,
name='pointwise_kernel',
regularizer=self.pointwise_regularizer,
constraint=self.pointwise_constraint)

if self.use_bias:
self.bias = self.add_weight((self.filters,),
self.bias = self.add_weight(shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down
6 changes: 3 additions & 3 deletions keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,19 +351,19 @@ def build(self, input_shape):
self.kernel_shape = kernel_shape
recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

self.kernel = self.add_weight(kernel_shape,
self.kernel = self.add_weight(shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
recurrent_kernel_shape,
shape=recurrent_kernel_shape,
initializer=self.recurrent_initializer,
name='recurrent_kernel',
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
if self.use_bias:
self.bias = self.add_weight((self.filters * 4,),
self.bias = self.add_weight(shape=(self.filters * 4,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down
4 changes: 2 additions & 2 deletions keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,13 +820,13 @@ def build(self, input_shape):
assert len(input_shape) >= 2
input_dim = input_shape[-1]

self.kernel = self.add_weight((input_dim, self.units),
self.kernel = self.add_weight(shape=(input_dim, self.units),
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight((self.units,),
self.bias = self.add_weight(shape=(self.units,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down
2 changes: 1 addition & 1 deletion keras/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, input_dim, output_dim,

def build(self, input_shape):
self.embeddings = self.add_weight(
(self.input_dim, self.output_dim),
shape=(self.input_dim, self.output_dim),
initializer=self.embeddings_initializer,
name='embeddings',
regularizer=self.embeddings_regularizer,
Expand Down
8 changes: 4 additions & 4 deletions keras/layers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ def build(self, input_shape):
self.kernel_size[0] * input_dim,
self.filters)
self.kernel = self.add_weight(
self.kernel_shape,
shape=self.kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
(output_length, self.filters),
shape=(output_length, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down Expand Up @@ -325,13 +325,13 @@ def build(self, input_shape):
self.kernel_shape = (output_row * output_col,
self.kernel_size[0] * self.kernel_size[1] * input_filter,
self.filters)
self.kernel = self.add_weight(self.kernel_shape,
self.kernel = self.add_weight(shape=self.kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight((output_row, output_col, self.filters),
self.bias = self.add_weight(shape=(output_row, output_col, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
Expand Down
8 changes: 4 additions & 4 deletions keras/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,28 +96,28 @@ def build(self, input_shape):
shape = (dim,)

if self.scale:
self.gamma = self.add_weight(shape,
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape,
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.moving_mean = self.add_weight(
shape,
shape=shape,
name='moving_mean',
initializer=self.moving_mean_initializer,
trainable=False)
self.moving_variance = self.add_weight(
shape,
shape=shape,
name='moving_variance',
initializer=self.moving_variance_initializer,
trainable=False)
Expand Down
18 changes: 9 additions & 9 deletions keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,19 +471,19 @@ def build(self, input_shape):
if self.stateful:
self.reset_states()

self.kernel = self.add_weight((self.input_dim, self.units),
self.kernel = self.add_weight(shape=(self.input_dim, self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
(self.units, self.units),
shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
if self.use_bias:
self.bias = self.add_weight((self.units,),
self.bias = self.add_weight(shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
Expand Down Expand Up @@ -690,20 +690,20 @@ def build(self, input_shape):
if self.stateful:
self.reset_states()

self.kernel = self.add_weight((self.input_dim, self.units * 3),
self.kernel = self.add_weight(shape=(self.input_dim, self.units * 3),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
(self.units, self.units * 3),
shape=(self.units, self.units * 3),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)

if self.use_bias:
self.bias = self.add_weight((self.units * 3,),
self.bias = self.add_weight(shape=(self.units * 3,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
Expand Down Expand Up @@ -970,20 +970,20 @@ def build(self, input_shape):
if self.stateful:
self.reset_states()

self.kernel = self.add_weight((self.input_dim, self.units * 4),
self.kernel = self.add_weight(shape=(self.input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
(self.units, self.units * 4),
shape=(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)

if self.use_bias:
self.bias = self.add_weight((self.units * 4,),
self.bias = self.add_weight(shape=(self.units * 4,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
Expand Down
17 changes: 17 additions & 0 deletions keras/legacy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,20 @@ def generator_methods_args_preprocessor(args, kwargs):
legacy_input_support = generate_legacy_interface(
allowed_positional_args=None,
conversions=[('input_dtype', 'dtype')])


def add_weight_args_preprocessing(args, kwargs):
if len(args) > 1:
if isinstance(args[1], (tuple, list)):
kwargs['shape'] = args[1]
args = (args[0],) + args[2:]
if len(args) > 1:
if isinstance(args[1], six.string_types):
kwargs['name'] = args[1]
args = (args[0],) + args[2:]
return args, kwargs, []


legacy_add_weight_support = generate_legacy_interface(
allowed_positional_args=['name', 'shape'],
preprocessor=add_weight_args_preprocessing)

0 comments on commit 8830c53

Please sign in to comment.