merge conflict
fchollet committed Feb 8, 2016
2 parents 523e24e + 8d393f7 commit f27c5b0
Showing 25 changed files with 696 additions and 410 deletions.
3 changes: 2 additions & 1 deletion docs/README.md
@@ -8,4 +8,5 @@ Our documentation uses extended Markdown, as implemented by [MkDocs](http://mkdo
- install MkDocs: `pip install mkdocs`
- `cd` to the `docs/` folder and run:
    - `python autogen.py`
    - `mkdocs serve`
    - `mkdocs serve`    # Starts a local webserver: [localhost:8000](localhost:8000)
    - `mkdocs build`    # Builds a static site in "site" directory
4 changes: 2 additions & 2 deletions docs/autogen.py
@@ -117,8 +117,8 @@ def code_snippet(snippet):


def process_class_docstring(docstring):
    docstring = re.sub(r'    # (.*)\n',
                       r'    __\1__\n\n',
    docstring = re.sub(r'\n    # (.*)\n',
                       r'\n    __\1__\n\n',
                       docstring)

    docstring = re.sub(r'    ([^\s\\]+):(.*)\n',
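The change above anchors the heading pattern to a preceding newline, so only comment-style headings that start their own docstring line (such as "# Arguments") get converted into bold Markdown. A minimal sketch of the effect, using a made-up docstring rather than anything from the Keras source:

import re

docstring = '''Base layer.
    # Arguments
        units: int, output dimension.
'''

# With the newline anchor, only a heading that begins its own line is rewritten.
converted = re.sub(r'\n    # (.*)\n',
                   r'\n    __\1__\n\n',
                   docstring)
print(converted)  # the "# Arguments" line becomes "__Arguments__"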
1 change: 0 additions & 1 deletion docs/templates/objectives.md
@@ -19,7 +19,6 @@ For a few examples of such functions, check out the [objectives source](https://
## Available objectives

- __mean_squared_error__ / __mse__
- __root_mean_squared_error__ / __rmse__
- __mean_absolute_error__ / __mae__
- __mean_absolute_percentage_error__ / __mape__
- __mean_squared_logarithmic_error__ / __msle__
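Any objective on this list can be passed to `model.compile` by its string name (either the long name or the short alias). A minimal sketch, assuming the Sequential API used elsewhere in this commit; the layer sizes here are illustrative only:

from keras.models import Sequential
from keras.layers.core import Dense

model = Sequential()
model.add(Dense(1, input_shape=(10,)))
# by name; the alias 'mse' works as well, per the list above
model.compile(loss='mean_squared_error', optimizer='sgd')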
124 changes: 124 additions & 0 deletions examples/mnist_siamese_graph.py
@@ -0,0 +1,124 @@
'''Train a Siamese MLP on pairs of digits from the MNIST dataset.
It follows Hadsell-et-al.'06 [1] by computing the Euclidean distance on the
output of the shared network and by optimizing the contrastive loss (see paper
for more details).
[1] "Dimensionality Reduction by Learning an Invariant Mapping"
http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_siamese_graph.py
Gets to 99.5% test accuracy after 20 epochs.
3 seconds per epoch on a Titan X GPU
'''
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility

import random
from keras.datasets import mnist
from keras.models import Sequential, Graph
from keras.layers.core import Dense, Dropout, Lambda
from keras.optimizers import SGD, RMSprop
from keras import backend as K


def euclidean_distance(inputs):
    assert len(inputs) == 2, ('Euclidean distance needs '
                              '2 inputs, %d given' % len(inputs))
    u, v = inputs.values()
    return K.sqrt(K.sum(K.square(u - v), axis=1, keepdims=True))


def contrastive_loss(y, d):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    '''
    margin = 1
    return K.mean(y * K.square(d) + (1 - y) * K.square(K.maximum(margin - d, 0)))


def create_pairs(x, digit_indices):
    '''Positive and negative pair creation.
    Alternates between positive and negative pairs.
    '''
    pairs = []
    labels = []
    n = min([len(digit_indices[d]) for d in range(10)]) - 1
    for d in range(10):
        for i in range(n):
            z1, z2 = digit_indices[d][i], digit_indices[d][i+1]
            pairs += [[x[z1], x[z2]]]
            inc = random.randrange(1, 10)
            dn = (d + inc) % 10
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs += [[x[z1], x[z2]]]
            labels += [1, 0]
    return np.array(pairs), np.array(labels)


def create_base_network(input_dim):
    '''Base network to be shared (eq. to feature extraction).
    '''
    seq = Sequential()
    seq.add(Dense(128, input_shape=(input_dim,), activation='relu'))
    seq.add(Dropout(0.1))
    seq.add(Dense(128, activation='relu'))
    seq.add(Dropout(0.1))
    seq.add(Dense(128, activation='relu'))
    return seq


def compute_accuracy(predictions, labels):
    '''Compute classification accuracy with a fixed threshold on distances.
    '''
    return labels[predictions.ravel() < 0.5].mean()


# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
input_dim = 784
nb_epoch = 20

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(10)]
tr_pairs, tr_y = create_pairs(X_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(10)]
te_pairs, te_y = create_pairs(X_test, digit_indices)

# network definition
base_network = create_base_network(input_dim)

g = Graph()
g.add_input(name='input_a', input_shape=(input_dim,))
g.add_input(name='input_b', input_shape=(input_dim,))
g.add_shared_node(base_network, name='shared', inputs=['input_a', 'input_b'],
                  merge_mode='join')
g.add_node(Lambda(euclidean_distance), name='d', input='shared')
g.add_output(name='output', input='d')

# train
rms = RMSprop()
g.compile(loss={'output': contrastive_loss}, optimizer=rms)
g.fit({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1], 'output': tr_y},
      validation_data={'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1], 'output': te_y},
      batch_size=128,
      nb_epoch=nb_epoch)

# compute final accuracy on training and test sets
pred = g.predict({'input_a': tr_pairs[:, 0], 'input_b': tr_pairs[:, 1]})['output']
tr_acc = compute_accuracy(pred, tr_y)
pred = g.predict({'input_a': te_pairs[:, 0], 'input_b': te_pairs[:, 1]})['output']
te_acc = compute_accuracy(pred, te_y)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))
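A short numpy illustration of what contrastive_loss and the 0/1 pair labels above express (a standalone sketch, not part of the example file): a positive pair (label 1) is penalised by the squared distance, while a negative pair (label 0) is only penalised while its distance is still inside the margin.

import numpy as np

def contrastive_loss_np(y, d, margin=1.0):
    # same formula as contrastive_loss above, written with numpy for illustration
    return np.mean(y * d ** 2 + (1 - y) * np.maximum(margin - d, 0) ** 2)

y = np.array([1., 1., 0., 0.])      # 1 = same digit, 0 = different digits
d = np.array([0.1, 0.9, 0.2, 1.5])  # Euclidean distances from the shared network
print(contrastive_loss_np(y, d))    # the far-apart positive pair and the close negative pair dominate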
86 changes: 43 additions & 43 deletions keras/backend/tensorflow_backend.py
@@ -154,12 +154,12 @@ def mean(x, axis=None, keepdims=False):
def any(x, axis=None, keepdims=False):
    '''Bitwise reduction (logical OR).
    Return array of int8 (0s and 1s).
    Return array of uint8 (0s and 1s).
    '''
    axis = normalize_axis(axis, ndim(x))
    x = tf.cast(x, tf.bool)
    x = tf.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
    return tf.cast(x, tf.int8)
    return tf.cast(x, tf.uint8)


def argmax(x, axis=-1):
@@ -289,6 +289,7 @@ def repeat(x, n):
    if x has shape (samples, dim) and n=2,
    the output will have shape (samples, 2, dim)
    '''
    assert ndim(x) == 2
    tensors = [x] * n
    stacked = tf.pack(tensors)
    return tf.transpose(stacked, (1, 0, 2))
@@ -429,54 +430,53 @@ def rnn(step_function, inputs, initial_states,
    axes = [1, 0] + list(range(2, ndim))
    inputs = tf.transpose(inputs, (axes))
    input_list = tf.unpack(inputs)
    if mask is None:
        mask = ones_like(tf.slice(inputs, [0, 0, 0], [-1, -1, 1]))
        inputs_shape = inputs.get_shape()

        # TODO: the mask's shape should be automatically inferred, by
        # tensorflow yet for some reason it fails to in some test-cases. This
        # fixes the issue, but should be removed in future.
        mask.set_shape([inputs_shape[0].value, inputs_shape[1].value, 1])
        mask = tf.cast(mask, tf.bool)
    else:
        # Transpose not supported by bool tensor types, hence round-trip to uint8.
        mask = tf.cast(tf.transpose(tf.cast(mask, tf.uint8), axes), tf.bool)

    mask_list = tf.unpack(mask)

    states = initial_states
    successive_states = []
    successive_outputs = []
    if go_backwards:
        input_list.reverse()

    for input, mask_t in zip(input_list, mask_list):
        output, new_states = step_function(input, states)

        # tf.select needs its condition tensor to be the same shape as its two
        # result tensors, but in our case the condition (mask) tensor is
        # (nsamples, 1), and A and B are (nsamples, ndimensions). So we need to
        # broadcast the mask to match the shape of A and B. That's what the
        # tile call does, is just repeat the mask along its second dimension
        # ndimensions times.
        tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]]))

        if len(successive_outputs) == 0:
            prev_output = zeros_like(output)
        else:
            prev_output = successive_outputs[-1]

        output = tf.select(tiled_mask_t, output, prev_output)

        return_states = []
        for state, new_state in zip(states, new_states):
            # (see earlier comment for tile explanation)
            tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(new_state)[1]]))
            return_states.append(tf.select(tiled_mask_t, new_state, state))

        states = return_states
        successive_outputs.append(output)
        successive_states.append(states)
    if mask is not None:
        # Transpose not supported by bool tensor types, hence round-trip to uint8.
        mask = tf.cast(mask, tf.uint8)
        if len(mask.get_shape()) == ndim-1:
            mask = expand_dims(mask)
        mask = tf.cast(tf.transpose(mask, axes), tf.bool)
        mask_list = tf.unpack(mask)

        for input, mask_t in zip(input_list, mask_list):
            output, new_states = step_function(input, states)

            # tf.select needs its condition tensor to be the same shape as its two
            # result tensors, but in our case the condition (mask) tensor is
            # (nsamples, 1), and A and B are (nsamples, ndimensions). So we need to
            # broadcast the mask to match the shape of A and B. That's what the
            # tile call does, is just repeat the mask along its second dimension
            # ndimensions times.
            tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(output)[1]]))

            if len(successive_outputs) == 0:
                prev_output = zeros_like(output)
            else:
                prev_output = successive_outputs[-1]

            output = tf.select(tiled_mask_t, output, prev_output)

            return_states = []
            for state, new_state in zip(states, new_states):
                # (see earlier comment for tile explanation)
                tiled_mask_t = tf.tile(mask_t, tf.pack([1, tf.shape(new_state)[1]]))
                return_states.append(tf.select(tiled_mask_t, new_state, state))

            states = return_states
            successive_outputs.append(output)
            successive_states.append(states)
    else:
        for input in input_list:
            output, states = step_function(input, states)
            successive_outputs.append(output)
            successive_states.append(states)

    last_output = successive_outputs[-1]
    outputs = tf.pack(successive_outputs)
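The tile/select logic in the rnn change above can be summarised with a small numpy sketch (illustrative names only, not backend code): the per-timestep mask of shape (nsamples, 1) is repeated across the feature axis so it matches the step output, and masked samples keep their previous output instead of the new one.

import numpy as np

nsamples, ndimensions = 3, 4
mask_t = np.array([[True], [False], [True]])                  # (nsamples, 1)
output = np.arange(nsamples * ndimensions, dtype=float).reshape(nsamples, ndimensions)
prev_output = np.zeros((nsamples, ndimensions))

tiled_mask_t = np.tile(mask_t, (1, ndimensions))              # (nsamples, ndimensions), like tf.tile
masked_output = np.where(tiled_mask_t, output, prev_output)   # like tf.select
print(masked_output)                                          # row 1 is masked, so it keeps prev_output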
27 changes: 11 additions & 16 deletions keras/backend/theano_backend.py
@@ -41,18 +41,8 @@ def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
        raise Exception('Specify either a shape or ndim value.')
    if shape is not None:
        ndim = len(shape)
    if ndim == 0:
        return T.scalar(name=name, dtype=dtype)
    elif ndim == 1:
        return T.vector(name=name, dtype=dtype)
    elif ndim == 2:
        return T.matrix(name=name, dtype=dtype)
    elif ndim == 3:
        return T.tensor3(name=name, dtype=dtype)
    elif ndim == 4:
        return T.tensor4(name=name, dtype=dtype)
    else:
        raise Exception('ndim too large: ' + str(ndim))
    broadcast = (False,) * ndim
    return T.TensorType(dtype, broadcast)(name)


def shape(x):
@@ -281,9 +271,9 @@ def repeat(x, n):
    If x has shape (samples, dim) and n=2,
    the output will have shape (samples, 2, dim).
    '''
    tensors = [x] * n
    stacked = T.stack(*tensors)
    return stacked.dimshuffle((1, 0, 2))
    assert x.ndim == 2
    x = x.dimshuffle((0, 'x', 1))
    return T.extra_ops.repeat(x, n, axis=1)


def tile(x, n):
@@ -427,7 +417,7 @@ def rnn(step_function, inputs, initial_states,
        the step function.
    go_backwards: boolean. If True, do the iteration over
        the time dimension in reverse order.
    mask: binary tensor with shape (samples, time, 1),
    mask: binary tensor with shape (samples, time),
        with a zero for every element that is masked.

    Returns
@@ -447,7 +437,11 @@ def rnn(step_function, inputs, initial_states,
    inputs = inputs.dimshuffle(axes)

    if mask is not None:
        if mask.ndim == ndim-1:
            mask = expand_dims(mask)
        assert mask.ndim == ndim
        mask = mask.dimshuffle(axes)

        # build an all-zero tensor of shape (samples, output_dim)
        initial_output = step_function(inputs[0], initial_states)[0] * 0
        # Theano gets confused by broadcasting patterns in the scan op
@@ -674,6 +668,7 @@ def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
        pool_out = pool_out.dimshuffle((0, 2, 3, 1))
    return pool_out


# RANDOMNESS


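The rewritten Theano repeat keeps the same contract as before; a numpy sketch of the shape transformation it performs (illustrative only, not backend code): a (samples, dim) matrix gains a new middle axis and is repeated n times along it.

import numpy as np

x = np.array([[1., 2., 3.],
              [4., 5., 6.]])                    # (samples, dim) = (2, 3)
n = 2
repeated = np.repeat(x[:, None, :], n, axis=1)  # insert axis 1, then repeat along it
print(repeated.shape)                           # (2, 2, 3) = (samples, n, dim)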
8 changes: 4 additions & 4 deletions keras/layers/advanced_activations.py
@@ -6,8 +6,8 @@

class LeakyReLU(MaskedLayer):
    '''Special version of a Rectified Linear Unit
    that allows a small gradient when the unit is not active
    (`f(x) = alpha*x for x < 0`).
    that allows a small gradient when the unit is not active:
    `f(x) = alpha*x for x < 0`.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
@@ -60,7 +60,7 @@ def __init__(self, init='zero', weights=None, **kwargs):
    def build(self):
        input_shape = self.input_shape[1:]
        self.alphas = self.init(input_shape)
        self.params = [self.alphas]
        self.trainable_weights = [self.alphas]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
@@ -142,7 +142,7 @@ def build(self):
        input_shape = self.input_shape[1:]
        self.alphas = K.variable(self.alpha_init * np.ones(input_shape))
        self.betas = K.variable(self.beta_init * np.ones(input_shape))
        self.params = [self.alphas, self.betas]
        self.trainable_weights = [self.alphas, self.betas]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
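The reworded LeakyReLU docstring above describes f(x) = alpha*x for x < 0 and f(x) = x otherwise; a one-line numpy sketch of that definition (the alpha value here is just an example):

import numpy as np

def leaky_relu(x, alpha=0.3):
    # pass negative inputs through with a small slope instead of zeroing them
    return np.where(x < 0, alpha * x, x)

print(leaky_relu(np.array([-2.0, -0.5, 0.0, 1.5])))   # [-0.6  -0.15  0.    1.5 ]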
