Skip to content

Commit 52149d0

Browse files
committed
Add remaining methods to sequence iterator.
1 parent 06c4127 commit 52149d0

File tree

4 files changed

+53
-20
lines changed

4 files changed

+53
-20
lines changed

Lib/test/test_iter.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,13 @@ def test_iter_class_iter(self):
155155
self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
156156

157157
# Test for loop on a sequence class without __iter__
158-
# TODO: RUSTPYTHON
159-
@unittest.expectedFailure
160158
def test_seq_class_for(self):
161159
self.check_for_loop(SequenceClass(10), list(range(10)))
162160

163161
# Test iter() on a sequence class without __iter__
164-
# TODO: RUSTPYTHON
165-
@unittest.expectedFailure
166162
def test_seq_class_iter(self):
167163
self.check_iterator(iter(SequenceClass(10)), list(range(10)))
168164

169-
# TODO: RUSTPYTHON
170-
@unittest.expectedFailure
171165
def test_mutating_seq_class_iter_pickle(self):
172166
orig = SequenceClass(5)
173167
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -204,8 +198,6 @@ def test_mutating_seq_class_iter_pickle(self):
204198
self.assertTrue(isinstance(it, collections.abc.Iterator))
205199
self.assertEqual(list(it), [])
206200

207-
# TODO: RUSTPYTHON
208-
@unittest.expectedFailure
209201
def test_mutating_seq_class_exhausted_iter(self):
210202
a = SequenceClass(5)
211203
exhit = iter(a)
@@ -908,8 +900,6 @@ def test_sinkstate_string(self):
908900
self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
909901
self.assertEqual(list(b), [])
910902

911-
# TODO: RUSTPYTHON
912-
@unittest.expectedFailure
913903
def test_sinkstate_sequence(self):
914904
# This used to fail
915905
a = SequenceClass(5)
@@ -1004,8 +994,6 @@ def test_iter_overflow(self):
1004994
with self.assertRaises(OverflowError):
1005995
next(it)
1006996

1007-
# TODO: RUSTPYTHON
1008-
@unittest.expectedFailure
1009997
def test_iter_neg_setstate(self):
1010998
it = iter(UnlimitedSequenceClass())
1011999
it.__setstate__(-42)

vm/src/builtins/iter.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
use crossbeam_utils::atomic::AtomicCell;
66

77
use super::pytype::PyTypeRef;
8+
use super::{int, PyInt};
89
use crate::slots::PyIter;
910
use crate::vm::VirtualMachine;
1011
use crate::{
@@ -24,8 +25,9 @@ pub enum IterStatus {
2425
#[pyclass(module = false, name = "iterator")]
2526
#[derive(Debug)]
2627
pub struct PySequenceIterator {
27-
pub position: AtomicCell<isize>,
28+
pub position: AtomicCell<usize>,
2829
pub obj: PyObjectRef,
30+
pub status: AtomicCell<IterStatus>,
2931
}
3032

3133
impl PyValue for PySequenceIterator {
@@ -36,26 +38,69 @@ impl PyValue for PySequenceIterator {
3638

3739
#[pyimpl(with(PyIter))]
3840
impl PySequenceIterator {
39-
pub fn new_forward(obj: PyObjectRef) -> Self {
41+
pub fn new(obj: PyObjectRef) -> Self {
4042
Self {
4143
position: AtomicCell::new(0),
4244
obj,
45+
status: AtomicCell::new(IterStatus::Active),
46+
}
47+
}
48+
49+
#[pymethod(magic)]
50+
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
51+
match self.status.load() {
52+
IterStatus::Active => {
53+
let pos = self.position.load();
54+
// return NotImplemented if no length is around.
55+
vm.obj_len(&self.obj)
56+
.map_or(vm.ctx.not_implemented(), |len| {
57+
PyInt::from(len.saturating_sub(pos)).into_object(vm)
58+
})
59+
}
60+
IterStatus::Exhausted => PyInt::from(0).into_object(vm),
4361
}
4462
}
4563

4664
#[pymethod(magic)]
47-
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<isize> {
48-
let pos = self.position.load();
49-
let len = vm.obj_len(&self.obj)?;
50-
Ok(len as isize - pos)
65+
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
66+
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
67+
Ok(match self.status.load() {
68+
IterStatus::Exhausted => vm
69+
.ctx
70+
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
71+
IterStatus::Active => vm.ctx.new_tuple(vec![
72+
iter,
73+
vm.ctx.new_tuple(vec![self.obj.clone()]),
74+
vm.ctx.new_int(self.position.load()),
75+
]),
76+
})
77+
}
78+
79+
#[pymethod(magic)]
80+
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
81+
// When we're exhausted, just return.
82+
if let IterStatus::Exhausted = self.status.load() {
83+
return Ok(());
84+
}
85+
if let Some(i) = state.payload::<PyInt>() {
86+
self.position
87+
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
88+
Ok(())
89+
} else {
90+
Err(vm.new_type_error("an integer is required.".to_owned()))
91+
}
5192
}
5293
}
5394

5495
impl PyIter for PySequenceIterator {
5596
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
97+
if let IterStatus::Exhausted = zelf.status.load() {
98+
return Err(vm.new_stop_iteration());
99+
}
56100
let pos = zelf.position.fetch_add(1);
57101
match zelf.obj.get_item(pos, vm) {
58102
Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => {
103+
zelf.status.store(IterStatus::Exhausted);
59104
Err(vm.new_stop_iteration())
60105
}
61106
// also catches stop_iteration => stop_iteration

vm/src/iterator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult {
3636
vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || {
3737
format!("'{}' object is not iterable", iter_target.class().name)
3838
})?;
39-
Ok(PySequenceIterator::new_forward(iter_target)
39+
Ok(PySequenceIterator::new(iter_target)
4040
.into_ref(vm)
4141
.into_object())
4242
}

vm/src/pyobject.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ impl<T> PyIterable<T> {
741741
pub fn iter<'a>(&self, vm: &'a VirtualMachine) -> PyResult<PyIterator<'a, T>> {
742742
let iter_obj = match self.iterfn {
743743
Some(f) => f(self.iterable.clone(), vm)?,
744-
None => PySequenceIterator::new_forward(self.iterable.clone()).into_object(vm),
744+
None => PySequenceIterator::new(self.iterable.clone()).into_object(vm),
745745
};
746746

747747
let length_hint = iterator::length_hint(vm, iter_obj.clone())?;

0 commit comments

Comments
 (0)