Bug fix: Batch dot (keras-team#11458)
* rewrite batch dot for cntk

* tf batch dot minor fix

* better tests

* pep8

* minor fix

* rewrite tf batch_dot

* docstrings

* rem print

* typo

* update to latest cntk

* add batch_dot(var, var) support

* fix theano batch_dot for 2d inputs

* fix batch_dot in ref ops

* better handling of none axis

* docstring fix

* fix comment

* changes for @gabrieldemarmiesse

* ws fix

Co-Authored-By: farizrahman4u <[email protected]>
farizrahman4u authored and taehoonlee committed Oct 26, 2018
1 parent 3348712 commit 9148325
Showing 5 changed files with 343 additions and 120 deletions.
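For context, all three backend diffs below converge on the same default-axes rule: when `axes` is `None` and `y` has rank 2, the contraction defaults to `[x_ndim - 1, y_ndim - 1]` rather than `[x_ndim - 1, y_ndim - 2]` (which would be the batch axis). A minimal sketch of the resulting behavior, assuming only the public Keras backend API (`K.ones`, `K.batch_dot`, `K.int_shape`):

```python
from keras import backend as K

# Two batches of vectors; y is rank 2, so axes=None now defaults to
# axes = [x_ndim - 1, y_ndim - 1] = [1, 1] instead of reducing axis 0.
x = K.ones(shape=(32, 20))
y = K.ones(shape=(32, 20))

out = K.batch_dot(x, y)  # per-sample inner products
print(K.int_shape(out))  # (32, 1): rank-1 results are expanded to rank 2
```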
162 changes: 133 additions & 29 deletions keras/backend/cntk_backend.py
@@ -570,43 +570,142 @@ def batch_dot(x, y, axes=None):
x_shape = int_shape(x)
y_shape = int_shape(y)

x_ndim = len(x_shape)
y_ndim = len(y_shape)

if x_ndim < 2 or y_ndim < 2:
raise ValueError('Can not do batch_dot on inputs '
'with rank < 2. '
'Received inputs with shapes ' +
str(x_shape) + ' and ' +
str(y_shape) + '.')

x_batch_size = x_shape[0]
y_batch_size = y_shape[0]

if x_batch_size is not None and y_batch_size is not None:
if x_batch_size != y_batch_size:
raise ValueError('Can not do batch_dot on inputs '
'with different batch sizes. '
'Received inputs with shapes ' +
str(x_shape) + ' and ' +
str(y_shape) + '.')

if isinstance(axes, int):
axes = (axes, axes)
axes = [axes, axes]

if axes is None:
# behaves like tf.batch_matmul as default
axes = [len(x_shape) - 1, len(y_shape) - 2]
if y_ndim == 2:
axes = [x_ndim - 1, y_ndim - 1]
else:
axes = [x_ndim - 1, y_ndim - 2]

if b_any([isinstance(a, (list, tuple)) for a in axes]):
raise ValueError('Multiple target dimensions are not supported. ' +
'Expected: None, int, (int, int), ' +
'Provided: ' + str(axes))

if len(x_shape) == 2 and len(y_shape) == 2:
if axes[0] == axes[1]:
result = sum(x * y, axis=axes[0], keepdims=True)
return result if axes[0] == 1 else transpose(result)
else:
return sum(x * transpose(y), axis=axes[0], keepdims=True)
# if tuple, convert to list
axes = list(axes)

# convert negative indices
if axes[0] < 0:
axes[0] += x_ndim
if axes[1] < 0:
axes[1] += y_ndim

if 0 in axes:
raise ValueError('Can not perform batch_dot over axis 0.'
' If your inputs are not batched,'
' add a dummy batch dimension to your '
'inputs using K.expand_dims(x, 0)')
d1 = x_shape[axes[0]]
d2 = y_shape[axes[1]]

if d1 is not None and d2 is not None and d1 != d2:
raise ValueError('Can not do batch_dot on inputs with shapes ' +
str(x_shape) + ' and ' + str(y_shape) +
' with axes=' + str(axes) + '. x.shape[%d] != '
'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))

# Input shapes:
# x: (b_size, x1, ..., d, ..., xn)
# y: (b_size, y1, ..., d, ..., yn)
# where d is the dimension to reduce.

# Bring d to the last dimension in x
# x: (b_size, ..., d)

permute_pattern = list(range(x_ndim))
for i in range(axes[0], x_ndim - 1):
permute_pattern[i] = permute_pattern[i + 1]
permute_pattern[-1] = axes[0]

x = permute_dimensions(x, permute_pattern)

# Bring d to the second dimension in y
# y: (b_size, d, ...)
permute_pattern = list(range(y_ndim))

for i in range(axes[1], 1, -1):
permute_pattern[i] = permute_pattern[i - 1]
permute_pattern[1] = axes[1]
y = permute_dimensions(y, permute_pattern)

# Expand to rank 3 if needed
if x_ndim == 2:
x = expand_dims(x, 1)
x_expanded = True
else:
if len(y_shape) == 2:
y = expand_dims(y)

normalized_axis = []
normalized_axis.append(_normalize_axis(axes[0], x)[0])
normalized_axis.append(_normalize_axis(axes[1], y)[0])
# transpose
i = normalized_axis[0]
while i < len(x.shape) - 1:
x = C.swapaxes(x, i, i + 1)
i += 1
i = normalized_axis[1]
while i > 0:
y = C.swapaxes(y, i, i - 1)
i -= 1
result = C.times(x, y, output_rank=(len(y.shape) - 1)
if len(y.shape) > 1 else 1)
if len(y_shape) == 2:
result = squeeze(result, -1)
return result
x_expanded = False

if y_ndim == 2:
y = expand_dims(y, -1)
y_expanded = True
else:
y_expanded = False

x_shape = int_shape(x)
y_shape = int_shape(y)

# batch size might be lost at this point
x_batch_size = x_shape[0]
y_batch_size = y_shape[0]

if x_batch_size is None and y_batch_size is None:
dynamic_batch_size = True
elif x_batch_size is not None and y_batch_size is not None:
dynamic_batch_size = False
else:
raise ValueError('Can not perform batch_dot on inputs' +
' with both static and dynamic batch sizes.' +
' You probably attempted to perform the ' +
'operation on a placeholder and a variable, ' +
'which is not yet supported on the CNTK backend.')

if dynamic_batch_size:
result = C.times(x, y, output_rank=y_ndim - 2 + int(y_expanded))
else:
result = []

for i in range(x_batch_size):
xi = x[i]
yi = y[i]
if ndim(xi) == ndim(x): # for older versions of CNTK
xi = squeeze(xi, 0)
yi = squeeze(yi, 0)
result.append(C.times(xi, yi, output_rank=y_ndim - 2 + int(y_expanded)))
result = stack(result, 0)

if x_expanded:
result = squeeze(result, 1)

if y_expanded:
result = squeeze(result, -1)

if ndim(result) == 1:
return expand_dims(result)
return result


def transpose(x):
@@ -1133,6 +1232,11 @@ def concatenate(tensors, axis=-1):
return C.splice(*tensors, axis=axis[0])


def stack(x, axis=0):
x = [expand_dims(t, axis) for t in x]
return concatenate(x, axis)


def flatten(x):
return reshape(x, (-1,))

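The rewritten CNTK path above permutes the reduced dimension to the last axis of `x` and the second axis of `y`, then either calls `C.times` once (dynamic batch size) or loops over samples and rebuilds the batch axis with the new `stack` helper (static batch size). A rough NumPy analogue of that static-batch fallback, purely illustrative and not CNTK API (`batch_dot_static` is a hypothetical helper name):

```python
import numpy as np

def batch_dot_static(x, y):
    # x: (b, ..., d), y: (b, d, ...) -- the reduced axis d has already
    # been permuted into position, as in the CNTK code above.
    samples = [np.dot(x[i], y[i]) for i in range(x.shape[0])]
    return np.stack(samples, axis=0)  # rebuild the batch axis, like stack()

x = np.ones((32, 5, 20))   # d=20 is the last axis of x
y = np.ones((32, 20, 30))  # d=20 is the second axis of y
print(batch_dot_static(x, y).shape)  # (32, 5, 30)
```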
138 changes: 99 additions & 39 deletions keras/backend/tensorflow_backend.py
@@ -1090,7 +1090,7 @@ def batch_dot(x, y, axes=None):
"""Batchwise dot product.
`batch_dot` is used to compute dot product of `x` and `y` when
`x` and `y` are data in batch, i.e. in a shape of
`x` and `y` are data in batches, i.e. in a shape of
`(batch_size, :)`.
`batch_dot` results in a tensor or variable with fewer dimensions
than the input. If the number of dimensions is reduced to 1,
@@ -1099,8 +1099,7 @@ def batch_dot(x, y, axes=None):
# Arguments
x: Keras tensor or variable with `ndim >= 2`.
y: Keras tensor or variable with `ndim >= 2`.
axes: list of (or single) int with target dimensions.
The lengths of `axes[0]` and `axes[1]` should be the same.
axes: int or tuple(int, int). Target dimensions to be reduced.
# Returns
A tensor with shape equal to the concatenation of `x`'s shape
@@ -1114,6 +1113,14 @@ def batch_dot(x, y, axes=None):
of `x.dot(y.T)`, although we never have to calculate the off-diagonal
elements.
Pseudocode:
```
inner_products = []
for xi, yi in zip(x, y):
inner_products.append(xi.dot(yi))
result = stack(inner_products)
```
Shape inference:
Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
If `axes` is (1, 2), to find the output shape of the resultant tensor,
@@ -1132,52 +1139,105 @@
```python
>>> x_batch = K.ones(shape=(32, 20, 1))
>>> y_batch = K.ones(shape=(32, 30, 20))
>>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=[1, 2])
>>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=(1, 2))
>>> K.int_shape(xy_batch_dot)
(32, 1, 30)
```
"""
x_shape = int_shape(x)
y_shape = int_shape(y)

x_ndim = len(x_shape)
y_ndim = len(y_shape)

if x_ndim < 2 or y_ndim < 2:
raise ValueError('Can not do batch_dot on inputs '
'with rank < 2. '
'Received inputs with shapes ' +
str(x_shape) + ' and ' +
str(y_shape) + '.')

x_batch_size = x_shape[0]
y_batch_size = y_shape[0]

if x_batch_size is not None and y_batch_size is not None:
if x_batch_size != y_batch_size:
raise ValueError('Can not do batch_dot on inputs '
'with different batch sizes. '
'Received inputs with shapes ' +
str(x_shape) + ' and ' +
str(y_shape) + '.')

if isinstance(axes, int):
axes = (axes, axes)
x_ndim = ndim(x)
y_ndim = ndim(y)
axes = [axes, axes]

if axes is None:
# behaves like tf.batch_matmul as default
axes = [x_ndim - 1, y_ndim - 2]
if y_ndim == 2:
axes = [x_ndim - 1, y_ndim - 1]
else:
axes = [x_ndim - 1, y_ndim - 2]

if py_any([isinstance(a, (list, tuple)) for a in axes]):
raise ValueError('Multiple target dimensions are not supported. ' +
'Expected: None, int, (int, int), ' +
'Provided: ' + str(axes))
if x_ndim > y_ndim:
diff = x_ndim - y_ndim
y = tf.reshape(y, tf.concat([tf.shape(y), [1] * (diff)], axis=0))
elif y_ndim > x_ndim:
diff = y_ndim - x_ndim
x = tf.reshape(x, tf.concat([tf.shape(x), [1] * (diff)], axis=0))
else:
diff = 0
if ndim(x) == 2 and ndim(y) == 2:
if axes[0] == axes[1]:
out = tf.reduce_sum(tf.multiply(x, y), axes[0])
else:
out = tf.reduce_sum(tf.multiply(tf.transpose(x, [1, 0]), y), axes[1])
else:
if axes is not None:
adj_x = None if axes[0] == ndim(x) - 1 else True
adj_y = True if axes[1] == ndim(y) - 1 else None
else:
adj_x = None
adj_y = None
out = tf.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim:
idx = x_ndim + y_ndim - 3
else:
idx = x_ndim - 1
out = tf.squeeze(out, list(range(idx, idx + diff)))
if ndim(out) == 1:
out = expand_dims(out, 1)
return out

# if tuple, convert to list
axes = list(axes)

# convert negative indices
if axes[0] < 0:
axes[0] += x_ndim
if axes[1] < 0:
axes[1] += y_ndim

# sanity checks
if 0 in axes:
raise ValueError('Can not perform batch_dot over axis 0.'
' If your inputs are not batched,'
' add a dummy batch dimension to your '
'inputs using K.expand_dims(x, 0)')

a0, a1 = axes
d1 = x_shape[a0]
d2 = y_shape[a1]

if d1 is not None and d2 is not None and d1 != d2:
raise ValueError('Can not do batch_dot on inputs with shapes ' +
str(x_shape) + ' and ' + str(y_shape) +
' with axes=' + str(axes) + '. x.shape[%d] != '
'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))

# bring the dimensions to be reduced to axis 1
if a0 != 1:
pattern = list(range(x_ndim))
for i in range(a0, 1, -1):
pattern[i] = pattern[i - 1]
pattern[1] = a0
x = permute_dimensions(x, pattern)
if a1 != 1:
pattern = list(range(y_ndim))
for i in range(a1, 1, -1):
pattern[i] = pattern[i - 1]
pattern[1] = a1
y = permute_dimensions(y, pattern)

# reshape to closest broadcastable shape
x_shape = tf.shape(x)
y_shape = tf.shape(y)

new_x_shape = tf.concat([x_shape, tf.ones_like(y_shape[2:])], 0)
new_y_shape = tf.concat([y_shape[:2], tf.ones_like(x_shape[2:]), y_shape[2:]], 0)

x = reshape(x, new_x_shape)
y = reshape(y, new_y_shape)

result = tf.reduce_sum(x * y, 1)

if ndim(result) == 1:
result = tf.expand_dims(result, -1)

return result


def transpose(x):
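The new TensorFlow implementation avoids `tf.matmul` entirely: after permuting the reduced axis to position 1 on both tensors, it pads each tensor with singleton dimensions so the two broadcast against each other, multiplies elementwise, and sums over axis 1. The same trick in plain NumPy, as an illustrative sketch (shapes and names `x2`/`y2` are assumptions for the example):

```python
import numpy as np

b, d = 32, 20
x = np.ones((b, d, 5))       # reduced axis d already permuted to position 1
y = np.ones((b, d, 30))

# Mirror new_x_shape / new_y_shape above: append ones to x for y's tail
# dims, and insert ones into y for x's tail dims.
x2 = x.reshape(b, d, 5, 1)
y2 = y.reshape(b, d, 1, 30)

out = (x2 * y2).sum(axis=1)  # broadcasted multiply, contract over d
print(out.shape)             # (32, 5, 30)
```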
18 changes: 9 additions & 9 deletions keras/backend/theano_backend.py
@@ -465,22 +465,22 @@ def batch_dot(x, y, axes=None):
axes = (axes, axes)
if axes is None:
# behaves like tf.batch_matmul as default
axes = [x.ndim - 1, y.ndim - 2]
if y.ndim == 2:
axes = [x.ndim - 1, y.ndim - 1]
else:
axes = [x.ndim - 1, y.ndim - 2]
if py_any([isinstance(a, (list, tuple)) for a in axes]):
raise ValueError('Multiple target dimensions are not supported. ' +
'Expected: None, int, (int, int), ' +
'Provided: ' + str(axes))
if isinstance(axes, tuple):
axes = list(axes)

# workaround because theano doesn't accept axes
# which contains the batch axis (0)
if axes[0] == 0:
x = transpose(x)
axes[0] = x.ndim - 1
if axes[1] == 0:
y = transpose(y)
axes[1] = y.ndim - 1
if 0 in axes:
raise ValueError('Can not perform batch_dot over axis 0.'
' If your inputs are not batched,'
' add a dummy batch dimension to your '
'inputs using K.expand_dims(x, 0)')

out = T.batched_tensordot(x, y, axes=axes)
if ndim(out) == 1:
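With this change, all three backends reject `axes` containing 0 instead of silently transposing, and the error message points at a workaround. A short sketch of that workaround using the public backend API (shapes here are illustrative):

```python
from keras import backend as K

# Unbatched rank-2 inputs: give each a dummy batch dimension first,
# then contract over a non-zero axis.
x = K.ones(shape=(20, 30))
y = K.ones(shape=(20, 30))

xb = K.expand_dims(x, 0)                # (1, 20, 30)
yb = K.expand_dims(y, 0)                # (1, 20, 30)
out = K.batch_dot(xb, yb, axes=(2, 2))  # contract the size-30 axes
print(K.int_shape(out))                 # (1, 20, 20)
```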
