Skip to content

Commit cba8aa9

Browse files
committed
Drop iter_type_init, explicitly define __iter__ for iterators
1 parent 016ecf2 commit cba8aa9

File tree

5 files changed

+36
-51
lines changed

5 files changed

+36
-51
lines changed

vm/src/obj/objenumerate.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,17 @@ impl PyEnumerateRef {
5757

5858
Ok(result)
5959
}
60+
61+
fn iter(self, _vm: &VirtualMachine) -> Self {
62+
self
63+
}
6064
}
6165

6266
pub fn init(context: &PyContext) {
6367
let enumerate_type = &context.enumerate_type;
64-
objiter::iter_type_init(context, enumerate_type);
6568
extend_class!(context, enumerate_type, {
6669
"__new__" => context.new_rustfunc(enumerate_new),
67-
"__next__" => context.new_rustfunc(PyEnumerateRef::next)
70+
"__next__" => context.new_rustfunc(PyEnumerateRef::next),
71+
"__iter__" => context.new_rustfunc(PyEnumerateRef::iter),
6872
});
6973
}

vm/src/obj/objfilter.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ impl PyFilterRef {
5252
}
5353
}
5454
}
55+
56+
fn iter(self, _vm: &VirtualMachine) -> Self {
57+
self
58+
}
5559
}
5660

5761
pub fn init(context: &PyContext) {
5862
let filter_type = &context.filter_type;
5963

60-
objiter::iter_type_init(context, filter_type);
61-
6264
let filter_doc =
6365
"filter(function or None, iterable) --> filter object\n\n\
6466
Return an iterator yielding those items of iterable for which function(item)\n\
@@ -67,6 +69,7 @@ pub fn init(context: &PyContext) {
6769
extend_class!(context, filter_type, {
6870
"__new__" => context.new_rustfunc(filter_new),
6971
"__doc__" => context.new_str(filter_doc.to_string()),
70-
"__next__" => context.new_rustfunc(PyFilterRef::next)
72+
"__next__" => context.new_rustfunc(PyFilterRef::next),
73+
"__iter__" => context.new_rustfunc(PyFilterRef::iter),
7174
});
7275
}

vm/src/obj/objiter.rs

Lines changed: 13 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@
22
* Various types to support iteration.
33
*/
44

5-
use crate::function::PyFuncArgs;
6-
use crate::pyobject::{PyContext, PyIteratorValue, PyObjectRef, PyResult, TypeProtocol};
5+
use crate::pyobject::{PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult};
76
use crate::vm::VirtualMachine;
87

98
use super::objbytearray::PyByteArray;
109
use super::objbytes::PyBytes;
1110
use super::objrange::PyRange;
1211
use super::objsequence;
1312
use super::objtype;
14-
use crate::obj::objtype::PyClassRef;
1513

1614
/*
1715
* This helper function is called at multiple places. First, it is called
@@ -73,41 +71,12 @@ pub fn new_stop_iteration(vm: &VirtualMachine) -> PyObjectRef {
7371
vm.new_exception(stop_iteration_type, "End of iterator".to_string())
7472
}
7573

76-
/// Common setup for iter types, adds __iter__ method
77-
pub fn iter_type_init(context: &PyContext, iter_type: &PyClassRef) {
78-
let iter_func = {
79-
let cloned_iter_type = iter_type.clone();
80-
move |vm: &VirtualMachine, args: PyFuncArgs| {
81-
arg_check!(
82-
vm,
83-
args,
84-
required = [(iter, Some(cloned_iter_type.clone()))]
85-
);
86-
// Return self:
87-
Ok(iter.clone())
88-
}
89-
};
90-
91-
extend_class!(context, iter_type, {
92-
"__iter__" => context.new_rustfunc(iter_func)
93-
});
94-
}
95-
96-
// Sequence iterator:
97-
fn iter_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
98-
arg_check!(vm, args, required = [(iter_target, None)]);
74+
type PyIteratorValueRef = PyRef<PyIteratorValue>;
9975

100-
get_iter(vm, iter_target)
101-
}
102-
103-
fn iter_next(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
104-
arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]);
105-
106-
if let Some(PyIteratorValue {
107-
ref position,
108-
iterated_obj: ref iterated_obj_ref,
109-
}) = iter.payload()
110-
{
76+
impl PyIteratorValueRef {
77+
fn next(self, vm: &VirtualMachine) -> PyResult {
78+
let position = &self.position;
79+
let iterated_obj_ref = &self.iterated_obj;
11180
if let Some(range) = iterated_obj_ref.payload::<PyRange>() {
11281
if let Some(int) = range.get(position.get()) {
11382
position.set(position.get() + 1);
@@ -141,8 +110,10 @@ fn iter_next(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
141110
Err(new_stop_iteration(vm))
142111
}
143112
}
144-
} else {
145-
panic!("NOT IMPL");
113+
}
114+
115+
fn iter(self, _vm: &VirtualMachine) -> Self {
116+
self
146117
}
147118
}
148119

@@ -155,10 +126,9 @@ pub fn init(context: &PyContext) {
155126
supply its own iterator, or be a sequence.\n\
156127
In the second form, the callable is called until it returns the sentinel.";
157128

158-
iter_type_init(context, iter_type);
159129
extend_class!(context, iter_type, {
160-
"__new__" => context.new_rustfunc(iter_new),
161-
"__next__" => context.new_rustfunc(iter_next),
162-
"__doc__" => context.new_str(iter_doc.to_string())
130+
"__next__" => context.new_rustfunc(PyIteratorValueRef::next),
131+
"__iter__" => context.new_rustfunc(PyIteratorValueRef::iter),
132+
"__doc__" => context.new_str(iter_doc.to_string()),
163133
});
164134
}

vm/src/obj/objmap.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ impl PyMapRef {
4646
// the mapper itself can raise StopIteration which does stop the map iteration
4747
vm.invoke(self.mapper.clone(), next_objs)
4848
}
49+
50+
fn iter(self, _vm: &VirtualMachine) -> Self {
51+
self
52+
}
4953
}
5054

5155
pub fn init(context: &PyContext) {
@@ -55,10 +59,10 @@ pub fn init(context: &PyContext) {
5559
Make an iterator that computes the function using arguments from\n\
5660
each of the iterables. Stops when the shortest iterable is exhausted.";
5761

58-
objiter::iter_type_init(context, map_type);
5962
extend_class!(context, map_type, {
6063
"__new__" => context.new_rustfunc(map_new),
6164
"__next__" => context.new_rustfunc(PyMapRef::next),
65+
"__iter__" => context.new_rustfunc(PyMapRef::iter),
6266
"__doc__" => context.new_str(map_doc.to_string())
6367
});
6468
}

vm/src/obj/objzip.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,17 @@ impl PyZipRef {
4040
Ok(vm.ctx.new_tuple(next_objs))
4141
}
4242
}
43+
44+
fn iter(self, _vm: &VirtualMachine) -> Self {
45+
self
46+
}
4347
}
4448

4549
pub fn init(context: &PyContext) {
4650
let zip_type = &context.zip_type;
47-
objiter::iter_type_init(context, zip_type);
4851
extend_class!(context, zip_type, {
4952
"__new__" => context.new_rustfunc(zip_new),
50-
"__next__" => context.new_rustfunc(PyZipRef::next)
53+
"__next__" => context.new_rustfunc(PyZipRef::next),
54+
"__iter__" => context.new_rustfunc(PyZipRef::iter),
5155
});
5256
}

0 commit comments

Comments
 (0)