Skip to content

Commit 6abb673

Browse files
committed
Modify return type
1 parent 148e6cc commit 6abb673

File tree

2 files changed

+31
-50
lines changed

2 files changed

+31
-50
lines changed

vm/src/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ macro_rules! class_or_notimplemented {
221221
($vm:expr, $t:ty, $obj:expr) => {
222222
match $crate::pyobject::PyObject::downcast::<$t>($obj) {
223223
Ok(pyref) => pyref,
224-
Err(_) => return Ok($vm.ctx.not_implemented()),
224+
Err(_) => return Ok(PyArithmaticValue::NotImplemented),
225225
}
226226
};
227227
}

vm/src/stdlib/array.rs

Lines changed: 30 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use crate::obj::objstr::PyStringRef;
55
use crate::obj::objtype::PyClassRef;
66
use crate::obj::{objbool, objiter};
77
use crate::pyobject::{
8-
BorrowValue, Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult,
9-
PyValue, TryFromObject,
8+
Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
9+
TryFromObject, PyComparisonValue, PyArithmaticValue,
1010
};
1111
use crate::VirtualMachine;
1212

@@ -438,113 +438,94 @@ impl PyArray {
438438
}
439439

440440
#[pymethod(name = "__eq__")]
441-
fn eq(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
442-
let lhs = class_or_notimplemented!(vm, Self, lhs);
443-
let rhs = class_or_notimplemented!(vm, Self, rhs);
444-
let lhs = lhs.borrow_value();
441+
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
442+
let lhs = self.borrow_value();
443+
let rhs = class_or_notimplemented!(vm, Self, other);
445444
let rhs = rhs.borrow_value();
446445
if lhs.len() != rhs.len() {
447-
Ok(vm.ctx.new_bool(false))
446+
Ok(PyArithmaticValue::Implemented(false))
448447
} else {
449448
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
450449
let ne = objbool::boolval(vm, vm._ne(a, b)?)?;
451450
if ne {
452-
return Ok(vm.ctx.new_bool(false));
451+
return Ok(PyArithmaticValue::Implemented(false))
453452
}
454453
}
455-
Ok(vm.ctx.new_bool(true))
454+
Ok(PyArithmaticValue::Implemented(true))
456455
}
457456
}
458457

459458
#[pymethod(name = "__ne__")]
460-
fn ne(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
461-
let lhs = class_or_notimplemented!(vm, Self, lhs);
462-
let rhs = class_or_notimplemented!(vm, Self, rhs);
463-
let lhs = lhs.borrow_value();
464-
let rhs = rhs.borrow_value();
465-
if lhs.len() != rhs.len() {
466-
Ok(vm.new_bool(true))
467-
} else {
468-
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
469-
let ne = objbool::boolval(vm, vm._ne(a?, b?)?)?;
470-
if ne {
471-
return Ok(vm.new_bool(true));
472-
}
473-
}
474-
Ok(vm.new_bool(false))
475-
}
459+
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
460+
Ok(self.eq(other, vm)?.map(|v| !v))
476461
}
477462

478463
#[pymethod(name = "__lt__")]
479-
fn lt(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
480-
let lhs = class_or_notimplemented!(vm, Self, lhs);
481-
let rhs = class_or_notimplemented!(vm, Self, rhs);
482-
let lhs = lhs.borrow_value();
464+
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
465+
let lhs = self.borrow_value();
466+
let rhs = class_or_notimplemented!(vm, Self, other);
483467
let rhs = rhs.borrow_value();
484468

485469
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
486470
let lt = objbool::boolval(vm, vm._lt(a, b)?)?;
487471

488472
if lt {
489-
return Ok(vm.ctx.new_bool(true));
473+
return Ok(PyArithmaticValue::Implemented(true));
490474
}
491475
}
492476

493-
Ok(vm.ctx.new_bool(lhs.len() < rhs.len()))
477+
Ok(PyArithmaticValue::Implemented(lhs.len() < rhs.len()))
494478
}
495479

496480
#[pymethod(name = "__le__")]
497-
fn le(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
498-
let lhs = class_or_notimplemented!(vm, Self, lhs);
499-
let rhs = class_or_notimplemented!(vm, Self, rhs);
500-
let lhs = lhs.borrow_value();
481+
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
482+
let lhs = self.borrow_value();
483+
let rhs = class_or_notimplemented!(vm, Self, other);
501484
let rhs = rhs.borrow_value();
502485

503486
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
504487
let le = objbool::boolval(vm, vm._le(a, b)?)?;
505488

506489
if le {
507-
return Ok(vm.ctx.new_bool(true));
490+
return Ok(PyArithmaticValue::Implemented(true));
508491
}
509492
}
510493

511-
Ok(vm.ctx.new_bool(lhs.len() <= rhs.len()))
494+
Ok(PyArithmaticValue::Implemented(lhs.len() <= rhs.len()))
512495
}
513496

514497
#[pymethod(name = "__gt__")]
515-
fn gt(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
516-
let lhs = class_or_notimplemented!(vm, Self, lhs);
517-
let rhs = class_or_notimplemented!(vm, Self, rhs);
518-
let lhs = lhs.borrow_value();
498+
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
499+
let lhs = self.borrow_value();
500+
let rhs = class_or_notimplemented!(vm, Self, other);
519501
let rhs = rhs.borrow_value();
520502

521503
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
522504
let gt = objbool::boolval(vm, vm._gt(a, b)?)?;
523505

524506
if gt {
525-
return Ok(vm.ctx.new_bool(true));
507+
return Ok(PyArithmaticValue::Implemented(true));
526508
}
527509
}
528510

529-
Ok(vm.ctx.new_bool(lhs.len() > rhs.len()))
511+
Ok(PyArithmaticValue::Implemented(lhs.len() > rhs.len()))
530512
}
531513

532514
#[pymethod(name = "__ge__")]
533-
fn ge(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
534-
let lhs = class_or_notimplemented!(vm, Self, lhs);
535-
let rhs = class_or_notimplemented!(vm, Self, rhs);
536-
let lhs = lhs.borrow_value();
515+
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyComparisonValue> {
516+
let lhs = self.borrow_value();
517+
let rhs = class_or_notimplemented!(vm, Self, other);
537518
let rhs = rhs.borrow_value();
538519

539520
for (a, b) in lhs.iter(vm).zip(rhs.iter(vm)) {
540521
let ge = objbool::boolval(vm, vm._ge(a, b)?)?;
541522

542523
if ge {
543-
return Ok(vm.ctx.new_bool(true));
524+
return Ok(PyArithmaticValue::Implemented(true));
544525
}
545526
}
546527

547-
Ok(vm.ctx.new_bool(lhs.len() >= rhs.len()))
528+
Ok(PyArithmaticValue::Implemented(lhs.len() >= rhs.len()))
548529
}
549530

550531
#[pymethod(name = "__len__")]

0 commit comments

Comments
 (0)