Skip to content

Commit

Permalink
[Numpy] Add broadcast_to scalar case (apache#17233)
Browse files Browse the repository at this point in the history
* fix broadcast_to scalar case

* change broadcast_to namespace prefix

* remove unused import
  • Loading branch information
xidulu authored and haojin2 committed Jan 8, 2020
1 parent ba376af commit f17d19b
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 43 deletions.
35 changes: 0 additions & 35 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,41 +1132,6 @@ def _np__random_shuffle(x):
pass


def _np_broadcast_to(array, shape, out=None):
"""
Broadcast an array to a new shape.
Parameters
----------
array : ndarray
The array to broadcast.
shape : tuple, optional, default=[]
The shape of the desired array.
out : ndarray, optional
The output ndarray to hold the result.
Returns
-------
out : ndarray or list of ndarrays
Raises
------
MXNetError
- If the array is not compatible with the new shape according to NumPy's
broadcasting rules.
- If the shape of the output array is not consistent with the desired shape.
Examples
--------
>>> x = np.array([1, 2, 3])
>>> np.broadcast_to(x, (3, 3))
array([[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]])
"""
pass


def _npx_reshape(a, newshape, reverse=False, order='C'):
"""
Gives a new shape to an array without changing its data.
Expand Down
32 changes: 31 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..ndarray import NDArray

__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'invert', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
Expand Down Expand Up @@ -259,6 +259,36 @@ def ones_like(a, dtype=None, order='C', ctx=None, out=None):
return _npi.full_like(a, fill_value=1, dtype=dtype, ctx=None, out=None)


@set_module('mxnet.ndarray.numpy')
def broadcast_to(array, shape):
"""
Broadcast an array to a new shape.
Parameters
----------
array : ndarray or scalar
The array to broadcast.
shape : tuple
The shape of the desired array.
Returns
-------
broadcast : array
A readonly view on the original array with the given shape. It is
typically not contiguous. Furthermore, more than one element of a
broadcasted array may refer to a single memory location.
Raises
------
MXNetError
If the array is not compatible with the new shape according to NumPy's
broadcasting rules.
"""
if _np.isscalar(array):
return full(shape, array)
return _npi.broadcast_to(array, shape)


@set_module('mxnet.ndarray.numpy')
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments
"""
Expand Down
32 changes: 30 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ..ndarray.ndarray import _storage_type

__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape',
'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like',
'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'broadcast_to',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', 'delete',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
Expand Down Expand Up @@ -1998,7 +1998,7 @@ def squeeze(self, axis=None): # pylint: disable=arguments-differ
return _mx_np_op.squeeze(self, axis=axis)

def broadcast_to(self, shape): # pylint: disable=redefined-outer-name
return _mx_np_op.broadcast_to(self, shape)
return _mx_nd_np.broadcast_to(self, shape)

def broadcast_like(self, other):
raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like')
Expand Down Expand Up @@ -2283,6 +2283,34 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=rede
return _mx_nd_np.ones(shape, dtype, order, ctx)


@set_module('mxnet.numpy')
def broadcast_to(array, shape): # pylint: disable=redefined-outer-name
"""
Broadcast an array to a new shape.
Parameters
----------
array : ndarray or scalar
The array to broadcast.
shape : tuple
The shape of the desired array.
Returns
-------
broadcast : array
A readonly view on the original array with the given shape. It is
typically not contiguous. Furthermore, more than one element of a
broadcasted array may refer to a single memory location.
Raises
------
MXNetError
If the array is not compatible with the new shape according to NumPy's
broadcasting rules.
"""
return _mx_nd_np.broadcast_to(array, shape)


# pylint: disable=too-many-arguments, redefined-outer-name
@set_module('mxnet.numpy')
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Util functions with broadcast."""

from ..ndarray.ndarray import _get_broadcast_shape
from . import _op as _mx_np_op
from ..ndarray import numpy as _mx_nd_np


__all__ = ['broadcast_arrays']
Expand Down Expand Up @@ -62,4 +62,4 @@ def broadcast_arrays(*args):
# Common case where nothing needs to be broadcasted.
return list(args)

return [_mx_np_op.broadcast_to(array, shape) for array in args]
return [_mx_nd_np.broadcast_to(array, shape) for array in args]
32 changes: 31 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from builtins import slice as py_slice

__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'empty_like', 'bitwise_not', 'invert', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
Expand Down Expand Up @@ -1132,6 +1132,36 @@ def bitwise_not(x, out=None, **kwargs):
return _unary_func_helper(x, _npi.bitwise_not, _np.bitwise_not, out=out, **kwargs)


@set_module('mxnet.symbol.numpy')
def broadcast_to(array, shape):
"""
Broadcast an array to a new shape.
Parameters
----------
array : _Symbol or scalar
The array to broadcast.
shape : tuple
The shape of the desired array.
Returns
-------
broadcast : array
A readonly view on the original array with the given shape. It is
typically not contiguous. Furthermore, more than one element of a
broadcasted array may refer to a single memory location.
Raises
------
MXNetError
If the array is not compatible with the new shape according to NumPy's
broadcasting rules.
"""
if _np.isscalar(array):
return full(shape, array)
return _npi.broadcast_to(array, shape)


@set_module('mxnet.symbol.numpy')
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments
"""
Expand Down
2 changes: 1 addition & 1 deletion src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
return true;
}

NNVM_REGISTER_OP(_np_broadcast_to)
NNVM_REGISTER_OP(_npi_broadcast_to)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
Expand Down
2 changes: 1 addition & 1 deletion src/operator/numpy/np_broadcast_reduce_op_value.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ NNVM_REGISTER_OP(_npi_std)
NNVM_REGISTER_OP(_npi_var)
.set_attr<FCompute>("FCompute<gpu>", NumpyMomentsForward<gpu, false>);

NNVM_REGISTER_OP(_np_broadcast_to)
NNVM_REGISTER_OP(_npi_broadcast_to)
.set_attr<FCompute>("FCompute<gpu>", NumpyBroadcastToForward<gpu>);

NNVM_REGISTER_OP(_backward_np_broadcast_to)
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,15 @@ def __init__(self, dst_shape):
def hybrid_forward(self, F, x):
return F.np.broadcast_to(x, self._dst_shape)

class TestScalarBroadcastTo(HybridBlock):
def __init__(self, scalar, dst_shape):
super(TestScalarBroadcastTo, self).__init__()
self._scalar = scalar
self._dst_shape = dst_shape

def hybrid_forward(self, F, x):
return F.np.broadcast_to(self._scalar, self._dst_shape)

shapes = [
((), (1, 2, 4, 5)),
((1,), (4, 5, 6)),
Expand All @@ -1558,6 +1567,17 @@ def hybrid_forward(self, F, x):
expected_grad = collapse_sum_like(_np.ones_like(expected_ret), src_shape)
assert_almost_equal(a_mx.grad.asnumpy(), expected_grad, rtol=1e-5, atol=1e-6, use_broadcast=False)

# Test scalar case
scalar = 1.0
for _, dst_shape in shapes:
for hybridize in [True, False]:
test_scalar_broadcast_to = TestScalarBroadcastTo(scalar, dst_shape)
expected_ret = _np.broadcast_to(scalar, dst_shape)
with mx.autograd.record():
# `np.empty(())` serves as a dummpy input
ret = test_scalar_broadcast_to(np.empty(()))
assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)


@with_seed()
@use_np
Expand Down

0 comments on commit f17d19b

Please sign in to comment.