Skip to content

Commit

Permalink
Add simple equal operator (apache#4027)
Browse files Browse the repository at this point in the history
* add simple equal operator

* __eq__ override in python side; handle scalar situation; unittests
  • Loading branch information
jermainewang authored and piiswrong committed Dec 29, 2016
1 parent aa4c477 commit 7ba8307
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ def __pow__(self, other):
def __rpow__(self, other):
return power(other, self)

def __eq__(self, other):
return equal(self, other)

def __getstate__(self):
handle = self.handle
this = {'handle' : None}
Expand Down Expand Up @@ -849,6 +852,32 @@ def minimum(lhs, rhs):
None)
# pylint: enable= no-member, protected-access

def equal(lhs, rhs):
"""Return (lhs == rhs) element-wise.
Parameters
----------
lhs : Array or float value
left hand side operand
rhs : Array of float value
right hand side operand
Returns
-------
out: Array
result array
"""
# pylint: disable= no-member, protected-access
return _ufunc_helper(
lhs,
rhs,
_internal._equal,
lambda x, y: 1 if x == y else 0,
_internal._equal_scalar,
None)
# pylint: enable= no-member, protected-access

def true_divide(lhs, rhs):
""" Same as numpy's true_divide. It adjusts the output type to present the best answer,
regardless of input types.
Expand Down
4 changes: 4 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ NNVM_REGISTER_OP(_backward_minimum)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::le,
mshadow_op::gt>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_equal)
.add_alias("broadcast_equal").add_alias("_Equal")
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::eq>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_hypot)
.add_alias("broadcast_hypot").add_alias("_Hypot")
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::hypot>)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ NNVM_REGISTER_OP(_backward_minimum)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::le,
mshadow_op::gt>);

NNVM_REGISTER_OP(_equal)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_hypot)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::hypot>);

Expand Down
4 changes: 4 additions & 0 deletions src/operator/tensor/elemwise_binary_scalar_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_minimum_scalar)
.set_attr_parser([](NodeAttrs* attrs) {attrs->parsed = std::stod(attrs->dict["scalar"]);})
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarBackward<cpu, mshadow_op::le>);

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_equal_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarCompute<cpu, mshadow_op::eq>)
.add_alias("_EqualScalar");

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarCompute<cpu, mshadow_op::power>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_power_scalar"})
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/elemwise_binary_scalar_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ NNVM_REGISTER_OP(_minimum_scalar)
NNVM_REGISTER_OP(_backward_minimum_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarBackward<gpu, mshadow_op::le>);

NNVM_REGISTER_OP(_equal_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarCompute<gpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_power_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarCompute<gpu, mshadow_op::power>);

Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ def test_broadcast_to():
assert err < 1E-8
test_broadcast_to()

def test_ndarray_equal():
x = mx.nd.zeros((2, 3))
y = mx.nd.ones((2, 3))
z = x == y
assert (z.asnumpy() == np.zeros((2, 3))).all()
z = 0 == x
assert (z.asnumpy() == np.ones((2, 3))).all()

if __name__ == '__main__':
test_ndarray_setitem()
test_ndarray_crop()
Expand All @@ -356,3 +364,4 @@ def test_broadcast_to():
test_ndarray_onehot()
test_ndarray_fill()
test_reduce()
test_ndarray_equal()

0 comments on commit 7ba8307

Please sign in to comment.