Skip to content

Commit a35408a

Browse files
authored
Merge pull request RustPython#2028 from ugaemi/master
Implement __ne__ methods for array, bytearray
2 parents b57d943 + 1d554eb commit a35408a

File tree

6 files changed

+57
-34
lines changed

6 files changed

+57
-34
lines changed

tests/snippets/bytearray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,10 @@
713713
# mod
714714
assert bytearray('rust%bpython%b', 'utf-8') % (b' ', b'!') == bytearray(b'rust python!')
715715
assert bytearray('x=%i y=%f', 'utf-8') % (1, 2.5) == bytearray(b'x=1 y=2.500000')
716+
717+
# eq, ne
718+
a = bytearray(b'hello, world')
719+
b = a.copy()
720+
assert a.__ne__(b) is False
721+
b = bytearray(b'my bytearray')
722+
assert a.__ne__(b) is True

tests/snippets/stdlib_array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@
1313
a1.extend([4, 5, 6, 7])
1414

1515
assert a1 == array("h", [3, 2, 1, 0, 4, 5, 6, 7])
16+
17+
# eq, ne
18+
a = array("b", [0, 1, 2, 3])
19+
b = a
20+
assert a.__ne__(b) is False
21+
b = array("B", [3, 2, 1, 0])
22+
assert a.__ne__(b) is True

vm/src/bytesinner.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ impl PyBytesInner {
274274
self.cmp(other, |a, b| a == b, vm)
275275
}
276276

277+
pub fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
278+
self.eq(other, vm).map(|v| !v)
279+
}
280+
277281
pub fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
278282
self.cmp(other, |a, b| a >= b, vm)
279283
}

vm/src/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ macro_rules! class_or_notimplemented {
221221
($vm:expr, $t:ty, $obj:expr) => {
222222
match $crate::pyobject::PyObject::downcast::<$t>($obj) {
223223
Ok(pyref) => pyref,
224-
Err(_) => return Ok($vm.ctx.not_implemented()),
224+
Err(_) => return Ok(PyArithmaticValue::NotImplemented),
225225
}
226226
};
227227
}

vm/src/obj/objbytearray.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ impl PyByteArray {
123123
self.borrow_value().eq(other, vm)
124124
}
125125

126+
#[pymethod(name = "__ne__")]
127+
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
128+
self.borrow_value().ne(other, vm)
129+
}
130+
126131
#[pymethod(name = "__ge__")]
127132
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
128133
self.borrow_value().ge(other, vm)

vm/src/stdlib/array.rs

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use crate::obj::objstr::PyStringRef;
55
use crate::obj::objtype::PyClassRef;
66
use crate::obj::{objbool, objiter};
77
use crate::pyobject::{
8-
BorrowValue, Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult,
9-
PyValue, TryFromObject,
8+
BorrowValue, Either, IntoPyObject, PyArithmaticValue, PyClassImpl, PyComparisonValue,
9+
PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
1010
};
1111
use crate::VirtualMachine;
1212

@@ -438,94 +438,94 @@ impl PyArray {
438438
}
439439

440440
#[pymethod(name = "__eq__")]
441-
fn eq(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
442-
let lhs = class_or_notimplemented!(vm, Self, lhs);
443-
let rhs = class_or_notimplemented!(vm, Self, rhs);
444-
let lhs = lhs.borrow_value();
441+
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
442+
let lhs = self.borrow_value();
443+
let rhs = class_or_notimplemented!(vm, Self, other);
445444
let rhs = rhs.borrow_value();
446445
if lhs.len() != rhs.len() {
447-
Ok(vm.ctx.new_bool(false))
446+
Ok(PyArithmaticValue::Implemented(false))
448447
} else {
449448
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
450449
let ne = objbool::boolval(vm, vm._ne(a, b)?)?;
451450
if ne {
452-
return Ok(vm.ctx.new_bool(false));
451+
return Ok(PyArithmaticValue::Implemented(false));
453452
}
454453
}
455-
Ok(vm.ctx.new_bool(true))
454+
Ok(PyArithmaticValue::Implemented(true))
456455
}
457456
}
458457

458+
#[pymethod(name = "__ne__")]
459+
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
460+
Ok(self.eq(other, vm)?.map(|v| !v))
461+
}
462+
459463
#[pymethod(name = "__lt__")]
460-
fn lt(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
461-
let lhs = class_or_notimplemented!(vm, Self, lhs);
462-
let rhs = class_or_notimplemented!(vm, Self, rhs);
463-
let lhs = lhs.borrow_value();
464+
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
465+
let lhs = self.borrow_value();
466+
let rhs = class_or_notimplemented!(vm, Self, other);
464467
let rhs = rhs.borrow_value();
465468

466469
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
467470
let lt = objbool::boolval(vm, vm._lt(a, b)?)?;
468471

469472
if lt {
470-
return Ok(vm.ctx.new_bool(true));
473+
return Ok(PyArithmaticValue::Implemented(true));
471474
}
472475
}
473476

474-
Ok(vm.ctx.new_bool(lhs.len() < rhs.len()))
477+
Ok(PyArithmaticValue::Implemented(lhs.len() < rhs.len()))
475478
}
476479

477480
#[pymethod(name = "__le__")]
478-
fn le(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
479-
let lhs = class_or_notimplemented!(vm, Self, lhs);
480-
let rhs = class_or_notimplemented!(vm, Self, rhs);
481-
let lhs = lhs.borrow_value();
481+
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
482+
let lhs = self.borrow_value();
483+
let rhs = class_or_notimplemented!(vm, Self, other);
482484
let rhs = rhs.borrow_value();
483485

484486
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
485487
let le = objbool::boolval(vm, vm._le(a, b)?)?;
486488

487489
if le {
488-
return Ok(vm.ctx.new_bool(true));
490+
return Ok(PyArithmaticValue::Implemented(true));
489491
}
490492
}
491493

492-
Ok(vm.ctx.new_bool(lhs.len() <= rhs.len()))
494+
Ok(PyArithmaticValue::Implemented(lhs.len() <= rhs.len()))
493495
}
494496

495497
#[pymethod(name = "__gt__")]
496-
fn gt(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
497-
let lhs = class_or_notimplemented!(vm, Self, lhs);
498-
let rhs = class_or_notimplemented!(vm, Self, rhs);
499-
let lhs = lhs.borrow_value();
498+
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
499+
let lhs = self.borrow_value();
500+
let rhs = class_or_notimplemented!(vm, Self, other);
500501
let rhs = rhs.borrow_value();
501502

502503
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
503504
let gt = objbool::boolval(vm, vm._gt(a, b)?)?;
504505

505506
if gt {
506-
return Ok(vm.ctx.new_bool(true));
507+
return Ok(PyArithmaticValue::Implemented(true));
507508
}
508509
}
509510

510-
Ok(vm.ctx.new_bool(lhs.len() > rhs.len()))
511+
Ok(PyArithmaticValue::Implemented(lhs.len() > rhs.len()))
511512
}
512513

513514
#[pymethod(name = "__ge__")]
514-
fn ge(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
515-
let lhs = class_or_notimplemented!(vm, Self, lhs);
516-
let rhs = class_or_notimplemented!(vm, Self, rhs);
517-
let lhs = lhs.borrow_value();
515+
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
516+
let lhs = self.borrow_value();
517+
let rhs = class_or_notimplemented!(vm, Self, other);
518518
let rhs = rhs.borrow_value();
519519

520520
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
521521
let ge = objbool::boolval(vm, vm._ge(a, b)?)?;
522522

523523
if ge {
524-
return Ok(vm.ctx.new_bool(true));
524+
return Ok(PyArithmaticValue::Implemented(true));
525525
}
526526
}
527527

528-
Ok(vm.ctx.new_bool(lhs.len() >= rhs.len()))
528+
Ok(PyArithmaticValue::Implemented(lhs.len() >= rhs.len()))
529529
}
530530

531531
#[pymethod(name = "__len__")]

0 commit comments

Comments
 (0)