Skip to content

Commit 98d90c8

Browse files
committed
Add itertools.takewhile
1 parent 95a894a commit 98d90c8

File tree

2 files changed

+109
-1
lines changed

2 files changed

+109
-1
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,47 @@
8282
r = itertools.repeat(1, -1)
8383
with assertRaises(StopIteration):
8484
next(r)
85+
86+
87+
# itertools.takewhile tests
88+
89+
from itertools import takewhile as tw
90+
91+
t = tw(lambda n: n < 5, [1, 2, 5, 1, 3])
92+
assert next(t) == 1
93+
assert next(t) == 2
94+
with assertRaises(StopIteration):
95+
next(t)
96+
97+
# not iterable
98+
with assertRaises(TypeError):
99+
tw(lambda n: n < 1, 1)
100+
101+
# not callable
102+
t = tw(5, [1, 2])
103+
with assertRaises(TypeError):
104+
next(t)
105+
106+
# non-bool predicate
107+
t = tw(lambda n: n, [1, 2, 0])
108+
assert next(t) == 1
109+
assert next(t) == 2
110+
with assertRaises(StopIteration):
111+
next(t)
112+
113+
# bad predicate prototype
114+
t = tw(lambda: True, [1])
115+
with assertRaises(TypeError):
116+
next(t)
117+
118+
# StopIteration before attempting to call (bad) predicate
119+
t = tw(lambda: True, [])
120+
with assertRaises(StopIteration):
121+
next(t)
122+
123+
# doesn't try again after the first predicate failure
124+
t = tw(lambda n: n < 1, [1, 0])
125+
with assertRaises(StopIteration):
126+
next(t)
127+
with assertRaises(StopIteration):
128+
next(t)

vm/src/stdlib/itertools.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use std::ops::{AddAssign, SubAssign};
55
use num_bigint::BigInt;
66

77
use crate::function::OptionalArg;
8+
use crate::obj::objbool;
89
use crate::obj::objint::{PyInt, PyIntRef};
9-
use crate::obj::objiter::new_stop_iteration;
10+
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
1011
use crate::obj::objtype::PyClassRef;
1112
use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
1213
use crate::vm::VirtualMachine;
@@ -121,6 +122,65 @@ impl PyItertoolsRepeat {
121122
}
122123
}
123124

125+
#[pyclass]
126+
#[derive(Debug)]
127+
struct PyItertoolsTakewhile {
128+
predicate: PyObjectRef,
129+
iterable: PyObjectRef,
130+
stop_flag: RefCell<bool>,
131+
}
132+
133+
impl PyValue for PyItertoolsTakewhile {
134+
fn class(vm: &VirtualMachine) -> PyClassRef {
135+
vm.class("itertools", "takewhile")
136+
}
137+
}
138+
139+
#[pyimpl]
140+
impl PyItertoolsTakewhile {
141+
#[pymethod(name = "__new__")]
142+
fn new(
143+
_cls: PyClassRef,
144+
predicate: PyObjectRef,
145+
iterable: PyObjectRef,
146+
vm: &VirtualMachine,
147+
) -> PyResult {
148+
let iter = get_iter(vm, &iterable)?;
149+
150+
Ok(PyItertoolsTakewhile {
151+
predicate: predicate,
152+
iterable: iter,
153+
stop_flag: RefCell::new(false),
154+
}
155+
.into_ref(vm)
156+
.into_object())
157+
}
158+
159+
#[pymethod(name = "__next__")]
160+
fn next(&self, vm: &VirtualMachine) -> PyResult {
161+
if *self.stop_flag.borrow() {
162+
return Err(new_stop_iteration(vm));
163+
}
164+
165+
// might be StopIteration or anything else, which is propaged upwwards
166+
let obj = call_next(vm, &self.iterable)?;
167+
168+
let verdict = vm.invoke(self.predicate.clone(), vec![obj.clone()])?;
169+
let verdict = objbool::boolval(vm, verdict)?;
170+
if verdict {
171+
Ok(obj)
172+
} else {
173+
*self.stop_flag.borrow_mut() = true;
174+
Err(new_stop_iteration(vm))
175+
}
176+
}
177+
178+
#[pymethod(name = "__iter__")]
179+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
180+
zelf
181+
}
182+
}
183+
124184
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
125185
let ctx = &vm.ctx;
126186

@@ -130,8 +190,12 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
130190
let repeat = ctx.new_class("repeat", ctx.object());
131191
PyItertoolsRepeat::extend_class(ctx, &repeat);
132192

193+
let takewhile = ctx.new_class("takewhile", ctx.object());
194+
PyItertoolsTakewhile::extend_class(ctx, &takewhile);
195+
133196
py_module!(vm, "itertools", {
134197
"count" => count,
135198
"repeat" => repeat,
199+
"takewhile" => takewhile,
136200
})
137201
}

0 commit comments

Comments
 (0)