Skip to content

Commit c7fd54e

Browse files
committed
Convert iterators to pyclass macros
1 parent 1f002e8 commit c7fd54e

File tree

11 files changed

+125
-132
lines changed

11 files changed

+125
-132
lines changed

vm/src/obj/objbytearray.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::ops::{Deref, DerefMut};
77
use num_traits::ToPrimitive;
88

99
use crate::function::OptionalArg;
10-
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
10+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
1111
use crate::vm::VirtualMachine;
1212

1313
use super::objint;
@@ -84,11 +84,7 @@ pub fn init(context: &PyContext) {
8484
"upper" => context.new_rustfunc(PyByteArrayRef::upper)
8585
});
8686

87-
let bytearrayiterator_type = &context.bytearrayiterator_type;
88-
extend_class!(context, bytearrayiterator_type, {
89-
"__next__" => context.new_rustfunc(PyByteArrayIteratorRef::next),
90-
"__iter__" => context.new_rustfunc(PyByteArrayIteratorRef::iter),
91-
});
87+
PyByteArrayIterator::extend_class(context, &context.bytearrayiterator_type);
9288
}
9389

9490
fn bytearray_new(
@@ -287,6 +283,7 @@ mod tests {
287283
}
288284
}
289285

286+
#[pyclass]
290287
#[derive(Debug)]
291288
pub struct PyByteArrayIterator {
292289
position: Cell<usize>,
@@ -299,10 +296,10 @@ impl PyValue for PyByteArrayIterator {
299296
}
300297
}
301298

302-
type PyByteArrayIteratorRef = PyRef<PyByteArrayIterator>;
303-
304-
impl PyByteArrayIteratorRef {
305-
fn next(self, vm: &VirtualMachine) -> PyResult<u8> {
299+
#[pyimpl]
300+
impl PyByteArrayIterator {
301+
#[pymethod(name = "__next__")]
302+
fn next(&self, vm: &VirtualMachine) -> PyResult<u8> {
306303
if self.position.get() < self.bytearray.value.borrow().len() {
307304
let ret = self.bytearray.value.borrow()[self.position.get()];
308305
self.position.set(self.position.get() + 1);
@@ -312,7 +309,8 @@ impl PyByteArrayIteratorRef {
312309
}
313310
}
314311

315-
fn iter(self, _vm: &VirtualMachine) -> Self {
316-
self
312+
#[pymethod(name = "__iter__")]
313+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
314+
zelf
317315
}
318316
}

vm/src/obj/objbytes.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,7 @@ pub fn init(context: &PyContext) {
6464
extend_class!(context, bytes_type, {
6565
"fromhex" => context.new_rustfunc(PyBytesRef::fromhex),
6666
});
67-
let bytesiterator_type = &context.bytesiterator_type;
68-
extend_class!(context, bytesiterator_type, {
69-
"__next__" => context.new_rustfunc(PyBytesIteratorRef::next),
70-
"__iter__" => context.new_rustfunc(PyBytesIteratorRef::iter),
71-
});
67+
PyBytesIterator::extend_class(context, &context.bytesiterator_type);
7268
}
7369

7470
#[pyimpl]
@@ -271,6 +267,7 @@ impl PyBytesRef {
271267
}
272268
}
273269

270+
#[pyclass]
274271
#[derive(Debug)]
275272
pub struct PyBytesIterator {
276273
position: Cell<usize>,
@@ -283,10 +280,10 @@ impl PyValue for PyBytesIterator {
283280
}
284281
}
285282

286-
type PyBytesIteratorRef = PyRef<PyBytesIterator>;
287-
288-
impl PyBytesIteratorRef {
289-
fn next(self, vm: &VirtualMachine) -> PyResult<u8> {
283+
#[pyimpl]
284+
impl PyBytesIterator {
285+
#[pymethod(name = "__next__")]
286+
fn next(&self, vm: &VirtualMachine) -> PyResult<u8> {
290287
if self.position.get() < self.bytes.inner.len() {
291288
let ret = self.bytes[self.position.get()];
292289
self.position.set(self.position.get() + 1);
@@ -296,7 +293,8 @@ impl PyBytesIteratorRef {
296293
}
297294
}
298295

299-
fn iter(self, _vm: &VirtualMachine) -> Self {
300-
self
296+
#[pymethod(name = "__iter__")]
297+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
298+
zelf
301299
}
302300
}

vm/src/obj/objenumerate.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ use num_bigint::BigInt;
55
use num_traits::Zero;
66

77
use crate::function::OptionalArg;
8-
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
8+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
99
use crate::vm::VirtualMachine;
1010

1111
use super::objint::PyIntRef;
1212
use super::objiter;
1313
use super::objtype::PyClassRef;
1414

15+
#[pyclass]
1516
#[derive(Debug)]
1617
pub struct PyEnumerate {
1718
counter: RefCell<BigInt>,
@@ -44,8 +45,10 @@ fn enumerate_new(
4445
.into_ref_with_type(vm, cls)
4546
}
4647

47-
impl PyEnumerateRef {
48-
fn next(self, vm: &VirtualMachine) -> PyResult {
48+
#[pyimpl]
49+
impl PyEnumerate {
50+
#[pymethod(name = "__next__")]
51+
fn next(&self, vm: &VirtualMachine) -> PyResult {
4952
let iterator = &self.iterator;
5053
let counter = &self.counter;
5154
let next_obj = objiter::call_next(vm, iterator)?;
@@ -58,16 +61,15 @@ impl PyEnumerateRef {
5861
Ok(result)
5962
}
6063

61-
fn iter(self, _vm: &VirtualMachine) -> Self {
62-
self
64+
#[pymethod(name = "__iter__")]
65+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
66+
zelf
6367
}
6468
}
6569

6670
pub fn init(context: &PyContext) {
67-
let enumerate_type = &context.enumerate_type;
68-
extend_class!(context, enumerate_type, {
71+
PyEnumerate::extend_class(context, &context.enumerate_type);
72+
extend_class!(context, &context.enumerate_type, {
6973
"__new__" => context.new_rustfunc(enumerate_new),
70-
"__next__" => context.new_rustfunc(PyEnumerateRef::next),
71-
"__iter__" => context.new_rustfunc(PyEnumerateRef::iter),
7274
});
7375
}

vm/src/obj/objfilter.rs

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::pyobject::{IdProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
1+
use crate::pyobject::{IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
22
use crate::vm::VirtualMachine; // Required for arg_check! to use isinstance
33

44
use super::objbool;
@@ -7,6 +7,11 @@ use crate::obj::objtype::PyClassRef;
77

88
pub type PyFilterRef = PyRef<PyFilter>;
99

10+
/// filter(function or None, iterable) --> filter object
11+
///
12+
/// Return an iterator yielding those items of iterable for which function(item)
13+
/// is true. If function is None, return the items that are true.
14+
#[pyclass]
1015
#[derive(Debug)]
1116
pub struct PyFilter {
1217
predicate: PyObjectRef,
@@ -34,8 +39,10 @@ fn filter_new(
3439
.into_ref_with_type(vm, cls)
3540
}
3641

37-
impl PyFilterRef {
38-
fn next(self, vm: &VirtualMachine) -> PyResult {
42+
#[pyimpl]
43+
impl PyFilter {
44+
#[pymethod(name = "__next__")]
45+
fn next(&self, vm: &VirtualMachine) -> PyResult {
3946
let predicate = &self.predicate;
4047
let iterator = &self.iterator;
4148
loop {
@@ -53,23 +60,15 @@ impl PyFilterRef {
5360
}
5461
}
5562

56-
fn iter(self, _vm: &VirtualMachine) -> Self {
57-
self
63+
#[pymethod(name = "__iter__")]
64+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
65+
zelf
5866
}
5967
}
6068

6169
pub fn init(context: &PyContext) {
62-
let filter_type = &context.filter_type;
63-
64-
let filter_doc =
65-
"filter(function or None, iterable) --> filter object\n\n\
66-
Return an iterator yielding those items of iterable for which function(item)\n\
67-
is true. If function is None, return the items that are true.";
68-
69-
extend_class!(context, filter_type, {
70+
PyFilter::extend_class(context, &context.filter_type);
71+
extend_class!(context, &context.filter_type, {
7072
"__new__" => context.new_rustfunc(filter_new),
71-
"__doc__" => context.new_str(filter_doc.to_string()),
72-
"__next__" => context.new_rustfunc(PyFilterRef::next),
73-
"__iter__" => context.new_rustfunc(PyFilterRef::iter),
7473
});
7574
}

vm/src/obj/objiter.rs

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
use std::cell::Cell;
66

7-
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol};
7+
use crate::pyobject::{
8+
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol,
9+
};
810
use crate::vm::VirtualMachine;
911

1012
use super::objtype;
@@ -75,6 +77,7 @@ pub fn new_stop_iteration(vm: &VirtualMachine) -> PyObjectRef {
7577
vm.new_exception(stop_iteration_type, "End of iterator".to_string())
7678
}
7779

80+
#[pyclass]
7881
#[derive(Debug)]
7982
pub struct PySequenceIterator {
8083
pub position: Cell<usize>,
@@ -87,10 +90,10 @@ impl PyValue for PySequenceIterator {
8790
}
8891
}
8992

90-
type PySequenceIteratorRef = PyRef<PySequenceIterator>;
91-
92-
impl PySequenceIteratorRef {
93-
fn next(self, vm: &VirtualMachine) -> PyResult {
93+
#[pyimpl]
94+
impl PySequenceIterator {
95+
#[pymethod(name = "__next__")]
96+
fn next(&self, vm: &VirtualMachine) -> PyResult {
9497
let number = vm.ctx.new_int(self.position.get());
9598
match vm.call_method(&self.obj, "__getitem__", vec![number]) {
9699
Ok(val) => {
@@ -105,16 +108,12 @@ impl PySequenceIteratorRef {
105108
}
106109
}
107110

108-
fn iter(self, _vm: &VirtualMachine) -> Self {
109-
self
111+
#[pymethod(name = "__iter__")]
112+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
113+
zelf
110114
}
111115
}
112116

113117
pub fn init(context: &PyContext) {
114-
let iter_type = &context.iter_type;
115-
116-
extend_class!(context, iter_type, {
117-
"__next__" => context.new_rustfunc(PySequenceIteratorRef::next),
118-
"__iter__" => context.new_rustfunc(PySequenceIteratorRef::iter),
119-
});
118+
PySequenceIterator::extend_class(context, &context.iter_type);
120119
}

vm/src/obj/objlist.rs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ use num_traits::{One, Signed, ToPrimitive, Zero};
88

99
use crate::function::{OptionalArg, PyFuncArgs};
1010
use crate::pyobject::{
11-
IdProtocol, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
11+
IdProtocol, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
12+
TryFromObject,
1213
};
1314
use crate::vm::{ReprGuard, VirtualMachine};
1415

@@ -776,6 +777,7 @@ fn list_sort(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
776777
Ok(vm.get_none())
777778
}
778779

780+
#[pyclass]
779781
#[derive(Debug)]
780782
pub struct PyListIterator {
781783
pub position: Cell<usize>,
@@ -788,10 +790,10 @@ impl PyValue for PyListIterator {
788790
}
789791
}
790792

791-
type PyListIteratorRef = PyRef<PyListIterator>;
792-
793-
impl PyListIteratorRef {
794-
fn next(self, vm: &VirtualMachine) -> PyResult {
793+
#[pyimpl]
794+
impl PyListIterator {
795+
#[pymethod(name = "__next__")]
796+
fn next(&self, vm: &VirtualMachine) -> PyResult {
795797
if self.position.get() < self.list.elements.borrow().len() {
796798
let ret = self.list.elements.borrow()[self.position.get()].clone();
797799
self.position.set(self.position.get() + 1);
@@ -801,8 +803,9 @@ impl PyListIteratorRef {
801803
}
802804
}
803805

804-
fn iter(self, _vm: &VirtualMachine) -> Self {
805-
self
806+
#[pymethod(name = "__iter__")]
807+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
808+
zelf
806809
}
807810
}
808811

@@ -848,9 +851,5 @@ pub fn init(context: &PyContext) {
848851
"remove" => context.new_rustfunc(PyListRef::remove)
849852
});
850853

851-
let listiterator_type = &context.listiterator_type;
852-
extend_class!(context, listiterator_type, {
853-
"__next__" => context.new_rustfunc(PyListIteratorRef::next),
854-
"__iter__" => context.new_rustfunc(PyListIteratorRef::iter),
855-
});
854+
PyListIterator::extend_class(context, &context.listiterator_type);
856855
}

vm/src/obj/objmap.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
use crate::function::Args;
2-
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
2+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
33
use crate::vm::VirtualMachine;
44

55
use super::objiter;
66
use super::objtype::PyClassRef;
77

8+
/// map(func, *iterables) --> map object
9+
///
10+
/// Make an iterator that computes the function using arguments from
11+
/// each of the iterables. Stops when the shortest iterable is exhausted.
12+
#[pyclass]
813
#[derive(Debug)]
914
pub struct PyMap {
1015
mapper: PyObjectRef,
@@ -35,8 +40,10 @@ fn map_new(
3540
.into_ref_with_type(vm, cls.clone())
3641
}
3742

38-
impl PyMapRef {
39-
fn next(self, vm: &VirtualMachine) -> PyResult {
43+
#[pyimpl]
44+
impl PyMap {
45+
#[pymethod(name = "__next__")]
46+
fn next(&self, vm: &VirtualMachine) -> PyResult {
4047
let next_objs = self
4148
.iterators
4249
.iter()
@@ -47,22 +54,15 @@ impl PyMapRef {
4754
vm.invoke(self.mapper.clone(), next_objs)
4855
}
4956

50-
fn iter(self, _vm: &VirtualMachine) -> Self {
51-
self
57+
#[pymethod(name = "__iter__")]
58+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
59+
zelf
5260
}
5361
}
5462

5563
pub fn init(context: &PyContext) {
56-
let map_type = &context.map_type;
57-
58-
let map_doc = "map(func, *iterables) --> map object\n\n\
59-
Make an iterator that computes the function using arguments from\n\
60-
each of the iterables. Stops when the shortest iterable is exhausted.";
61-
62-
extend_class!(context, map_type, {
64+
PyMap::extend_class(context, &context.map_type);
65+
extend_class!(context, &context.map_type, {
6366
"__new__" => context.new_rustfunc(map_new),
64-
"__next__" => context.new_rustfunc(PyMapRef::next),
65-
"__iter__" => context.new_rustfunc(PyMapRef::iter),
66-
"__doc__" => context.new_str(map_doc.to_string())
6767
});
6868
}

0 commit comments

Comments
 (0)