Skip to content

Commit

Permalink
add ufunc_helper that helps all broadcasting issues; fix div operator…
Browse files Browse the repository at this point in the history
… in python3
  • Loading branch information
jermainewang committed Jun 13, 2016
1 parent d7fdf40 commit 9a42e15
Showing 1 changed file with 85 additions and 98 deletions.
183 changes: 85 additions & 98 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable= too-many-lines, redefined-builtin
"""NDArray API of mxnet."""
from __future__ import absolute_import
from __future__ import division

import ctypes
import warnings
Expand Down Expand Up @@ -144,13 +145,13 @@ def __imul__(self, other):
def __rmul__(self, other):
return self.__mul__(other)

def __div__(self, other):
def __truediv__(self, other):
return divide(self, other)

def __rdiv__(self, other):
def __rtruediv__(self, other):
return divide(other, self)

def __idiv__(self, other):
def __itruediv__(self, other):
if not self.writable:
raise ValueError('trying to divide from a readonly NDArray')
if isinstance(other, NDArray):
Expand All @@ -166,9 +167,6 @@ def __pow__(self, other):
def __rpow__(self, other):
return power(other, self)

def __truediv__(self, other):
return self.__div__(other)

def __getstate__(self):
this = self.__dict__.copy()
handle = this['handle']
Expand Down Expand Up @@ -541,42 +539,85 @@ def empty(shape, ctx=None, dtype=mx_real_t):
ctx = Context.default_ctx
return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))

def add(lhs, rhs):
""" Perform element-wise addition
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None):
""" Helper function for element-wise operation
The function will perform numpy-like broadcasting if needed and call different functions
Parameters
----------
lhs : Array or float value
left hand side operand
lhs : NDArray or numeric value
left hande side operand
rhs : Array of float value
rhs : NDArray or numeric value
right hand side operand
fn_array : function
function to be called if both lhs and rhs are of NDArray type
fn_scalar : function
function to be called if both lhs and rhs are numeric values
lfn_scalar : function
function to be called if lhs is NDArray while rhs is numeric value
rfn_scalar : function
function to be called if lhs is numeric value while rhs is NDArray;
if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar
Returns
-------
out: Array
out: NDArray
result array
"""
# pylint: disable= no-member, protected-access
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return lhs + rhs
return fn_scalar(lhs, rhs)
else:
return add(rhs, lhs)
if rfn_scalar is None:
# commutative function
return lfn_scalar(rhs, float(lhs))
else:
return rfn_scalar(rhs, float(lhs))
elif isinstance(rhs, numeric_types):
return NDArray._plus_scalar(lhs, float(rhs))
return lfn_scalar(lhs, float(rhs))
elif isinstance(rhs, NDArray):
# check whether broadcasting is needed
lsize = functools.reduce(operator.mul, lhs.shape)
rsize = functools.reduce(operator.mul, rhs.shape)
if lsize < rsize:
lhs = lhs.broadcast_to(rhs.shape)
elif lsize > rsize:
rhs = rhs.broadcast_to(lhs.shape)
return NDArray._plus(lhs, rhs)
return fn_array(lhs, rhs)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
# pylint: enable= no-member, protected-access

def add(lhs, rhs):
""" Perform element-wise addition
Parameters
----------
lhs : Array or float value
left hand side operand
rhs : Array of float value
right hand side operand
Returns
-------
out: Array
result array
"""
return _ufunc_helper(
lhs,
rhs,
NDArray._plus,
operator.add,
NDArray._plus_scalar,
None)

def subtract(lhs, rhs):
""" Perform element-wise subtract
Expand All @@ -593,27 +634,13 @@ def subtract(lhs, rhs):
out: Array
result array
"""
# pylint: disable= no-member, protected-access
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return lhs - rhs
elif isinstance(rhs, NDArray):
return NDArray._rminus_scalar(rhs, float(lhs))
else:
raise TypeError('type %s not supported' % str(type(rhs)))
elif isinstance(rhs, numeric_types):
return NDArray._minus_scalar(lhs, float(rhs))
elif isinstance(rhs, NDArray):
lsize = functools.reduce(operator.mul, lhs.shape)
rsize = functools.reduce(operator.mul, rhs.shape)
if lsize < rsize:
lhs = lhs.broadcast_to(rhs.shape)
elif lsize > rsize:
rhs = rhs.broadcast_to(lhs.shape)
return NDArray._minus(lhs, rhs)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
# pylint: enable= no-member, protected-access
return _ufunc_helper(
lhs,
rhs,
NDArray._minus,
operator.sub,
NDArray._minus_scalar,
NDArray._rminus_scalar)

def multiply(lhs, rhs):
""" Perform element-wise multiplication
Expand All @@ -631,25 +658,13 @@ def multiply(lhs, rhs):
out: Array
result array
"""
# pylint: disable= no-member, protected-access
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return lhs * rhs
else:
return multiply(rhs, lhs)
elif isinstance(rhs, numeric_types):
return NDArray._mul_scalar(lhs, float(rhs))
elif isinstance(rhs, NDArray):
lsize = functools.reduce(operator.mul, lhs.shape)
rsize = functools.reduce(operator.mul, rhs.shape)
if lsize < rsize:
lhs = lhs.broadcast_to(rhs.shape)
elif lsize > rsize:
rhs = rhs.broadcast_to(lhs.shape)
return NDArray._mul(lhs, rhs)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
# pylint: enable= no-member, protected-access
return _ufunc_helper(
lhs,
rhs,
NDArray._mul,
operator.mul,
NDArray._mul_scalar,
None)

def divide(lhs, rhs):
""" Perform element-wise divide
Expand All @@ -667,27 +682,13 @@ def divide(lhs, rhs):
out: Array
result array
"""
# pylint: disable= no-member, protected-access
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return lhs / rhs
elif isinstance(rhs, NDArray):
return NDArray._rdiv_scalar(rhs, float(lhs))
else:
raise TypeError('type %s not supported' % str(type(rhs)))
elif isinstance(rhs, numeric_types):
return NDArray._div_scalar(lhs, float(rhs))
elif isinstance(rhs, NDArray):
lsize = functools.reduce(operator.mul, lhs.shape)
rsize = functools.reduce(operator.mul, rhs.shape)
if lsize < rsize:
lhs = lhs.broadcast_to(rhs.shape)
elif lsize > rsize:
rhs = rhs.broadcast_to(lhs.shape)
return NDArray._div(lhs, rhs)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
# pylint: enable= no-member, protected-access
return _ufunc_helper(
lhs,
rhs,
NDArray._div,
operator.truediv,
NDArray._div_scalar,
NDArray._rdiv_scalar)

def power(lhs, rhs):
""" Perform power operator
Expand All @@ -705,27 +706,13 @@ def power(lhs, rhs):
out: Array
result array
"""
# pylint: disable= no-member, protected-access
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return lhs ** rhs
elif isinstance(rhs, NDArray):
return NDArray._rpower_scalar(rhs, float(lhs))
else:
raise TypeError('type %s not supported' % str(type(rhs)))
elif isinstance(rhs, numeric_types):
return NDArray._power_scalar(lhs, float(rhs))
elif isinstance(rhs, NDArray):
lsize = functools.reduce(operator.mul, lhs.shape)
rsize = functools.reduce(operator.mul, rhs.shape)
if lsize < rsize:
lhs = lhs.broadcast_to(rhs.shape)
elif lsize > rsize:
rhs = rhs.broadcast_to(lhs.shape)
return NDArray._power(lhs, rhs)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
# pylint: enable= no-member, protected-access
return _ufunc_helper(
lhs,
rhs,
NDArray._power,
operator.pow,
NDArray._power_scalar,
NDArray._rpower_scalar)

def true_divide(lhs, rhs):
""" Same as numpy's true_divide. It adjusts the output type to present the best answer,
Expand Down

0 comments on commit 9a42e15

Please sign in to comment.