From 9a42e15ef1618315ca319dea79c3a2ba50ea06db Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Sun, 12 Jun 2016 22:48:55 -0400 Subject: [PATCH] add ufunc_helper that helps all broadcasting issues; fix div operator in python3 --- python/mxnet/ndarray.py | 183 +++++++++++++++++++--------------------- 1 file changed, 85 insertions(+), 98 deletions(-) diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 9f296fd12677..7ab7b185271f 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -2,6 +2,7 @@ # pylint: disable= too-many-lines, redefined-builtin """NDArray API of mxnet.""" from __future__ import absolute_import +from __future__ import division import ctypes import warnings @@ -144,13 +145,13 @@ def __imul__(self, other): def __rmul__(self, other): return self.__mul__(other) - def __div__(self, other): + def __truediv__(self, other): return divide(self, other) - def __rdiv__(self, other): + def __rtruediv__(self, other): return divide(other, self) - def __idiv__(self, other): + def __itruediv__(self, other): if not self.writable: raise ValueError('trying to divide from a readonly NDArray') if isinstance(other, NDArray): @@ -166,9 +167,6 @@ def __pow__(self, other): def __rpow__(self, other): return power(other, self) - def __truediv__(self, other): - return self.__div__(other) - def __getstate__(self): this = self.__dict__.copy() handle = this['handle'] @@ -541,42 +539,85 @@ def empty(shape, ctx=None, dtype=mx_real_t): ctx = Context.default_ctx return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) -def add(lhs, rhs): - """ Perform element-wise addition +def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None): + """ Helper function for element-wise operation + The function will perform numpy-like broadcasting if needed and call different functions Parameters ---------- - lhs : Array or float value - left hand side operand + lhs : NDArray or numeric value + left hande side operand - rhs : Array of float value + rhs : NDArray or numeric value right hand side operand + fn_array : function + function to be called if both lhs and rhs are of NDArray type + + fn_scalar : function + function to be called if both lhs and rhs are numeric values + + lfn_scalar : function + function to be called if lhs is NDArray while rhs is numeric value + + rfn_scalar : function + function to be called if lhs is numeric value while rhs is NDArray; + if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar + Returns ------- - out: Array + out: NDArray result array """ # pylint: disable= no-member, protected-access if isinstance(lhs, numeric_types): if isinstance(rhs, numeric_types): - return lhs + rhs + return fn_scalar(lhs, rhs) else: - return add(rhs, lhs) + if rfn_scalar is None: + # commutative function + return lfn_scalar(rhs, float(lhs)) + else: + return rfn_scalar(rhs, float(lhs)) elif isinstance(rhs, numeric_types): - return NDArray._plus_scalar(lhs, float(rhs)) + return lfn_scalar(lhs, float(rhs)) elif isinstance(rhs, NDArray): + # check whether broadcasting is needed lsize = functools.reduce(operator.mul, lhs.shape) rsize = functools.reduce(operator.mul, rhs.shape) if lsize < rsize: lhs = lhs.broadcast_to(rhs.shape) elif lsize > rsize: rhs = rhs.broadcast_to(lhs.shape) - return NDArray._plus(lhs, rhs) + return fn_array(lhs, rhs) else: raise TypeError('type %s not supported' % str(type(rhs))) # pylint: enable= no-member, protected-access +def add(lhs, rhs): + """ Perform element-wise addition + + Parameters + ---------- + lhs : Array or float value + left hand side operand + + rhs : Array of float value + right hand side operand + + Returns + ------- + out: Array + result array + """ + return _ufunc_helper( + lhs, + rhs, + NDArray._plus, + operator.add, + NDArray._plus_scalar, + None) + def subtract(lhs, rhs): """ Perform element-wise subtract @@ -593,27 +634,13 @@ def subtract(lhs, rhs): out: Array result array """ - # pylint: disable= no-member, protected-access - if isinstance(lhs, numeric_types): - if isinstance(rhs, numeric_types): - return lhs - rhs - elif isinstance(rhs, NDArray): - return NDArray._rminus_scalar(rhs, float(lhs)) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - elif isinstance(rhs, numeric_types): - return NDArray._minus_scalar(lhs, float(rhs)) - elif isinstance(rhs, NDArray): - lsize = functools.reduce(operator.mul, lhs.shape) - rsize = functools.reduce(operator.mul, rhs.shape) - if lsize < rsize: - lhs = lhs.broadcast_to(rhs.shape) - elif lsize > rsize: - rhs = rhs.broadcast_to(lhs.shape) - return NDArray._minus(lhs, rhs) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - # pylint: enable= no-member, protected-access + return _ufunc_helper( + lhs, + rhs, + NDArray._minus, + operator.sub, + NDArray._minus_scalar, + NDArray._rminus_scalar) def multiply(lhs, rhs): """ Perform element-wise multiplication @@ -631,25 +658,13 @@ def multiply(lhs, rhs): out: Array result array """ - # pylint: disable= no-member, protected-access - if isinstance(lhs, numeric_types): - if isinstance(rhs, numeric_types): - return lhs * rhs - else: - return multiply(rhs, lhs) - elif isinstance(rhs, numeric_types): - return NDArray._mul_scalar(lhs, float(rhs)) - elif isinstance(rhs, NDArray): - lsize = functools.reduce(operator.mul, lhs.shape) - rsize = functools.reduce(operator.mul, rhs.shape) - if lsize < rsize: - lhs = lhs.broadcast_to(rhs.shape) - elif lsize > rsize: - rhs = rhs.broadcast_to(lhs.shape) - return NDArray._mul(lhs, rhs) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - # pylint: enable= no-member, protected-access + return _ufunc_helper( + lhs, + rhs, + NDArray._mul, + operator.mul, + NDArray._mul_scalar, + None) def divide(lhs, rhs): """ Perform element-wise divide @@ -667,27 +682,13 @@ def divide(lhs, rhs): out: Array result array """ - # pylint: disable= no-member, protected-access - if isinstance(lhs, numeric_types): - if isinstance(rhs, numeric_types): - return lhs / rhs - elif isinstance(rhs, NDArray): - return NDArray._rdiv_scalar(rhs, float(lhs)) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - elif isinstance(rhs, numeric_types): - return NDArray._div_scalar(lhs, float(rhs)) - elif isinstance(rhs, NDArray): - lsize = functools.reduce(operator.mul, lhs.shape) - rsize = functools.reduce(operator.mul, rhs.shape) - if lsize < rsize: - lhs = lhs.broadcast_to(rhs.shape) - elif lsize > rsize: - rhs = rhs.broadcast_to(lhs.shape) - return NDArray._div(lhs, rhs) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - # pylint: enable= no-member, protected-access + return _ufunc_helper( + lhs, + rhs, + NDArray._div, + operator.truediv, + NDArray._div_scalar, + NDArray._rdiv_scalar) def power(lhs, rhs): """ Perform power operator @@ -705,27 +706,13 @@ def power(lhs, rhs): out: Array result array """ - # pylint: disable= no-member, protected-access - if isinstance(lhs, numeric_types): - if isinstance(rhs, numeric_types): - return lhs ** rhs - elif isinstance(rhs, NDArray): - return NDArray._rpower_scalar(rhs, float(lhs)) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - elif isinstance(rhs, numeric_types): - return NDArray._power_scalar(lhs, float(rhs)) - elif isinstance(rhs, NDArray): - lsize = functools.reduce(operator.mul, lhs.shape) - rsize = functools.reduce(operator.mul, rhs.shape) - if lsize < rsize: - lhs = lhs.broadcast_to(rhs.shape) - elif lsize > rsize: - rhs = rhs.broadcast_to(lhs.shape) - return NDArray._power(lhs, rhs) - else: - raise TypeError('type %s not supported' % str(type(rhs))) - # pylint: enable= no-member, protected-access + return _ufunc_helper( + lhs, + rhs, + NDArray._power, + operator.pow, + NDArray._power_scalar, + NDArray._rpower_scalar) def true_divide(lhs, rhs): """ Same as numpy's true_divide. It adjusts the output type to present the best answer,