Skip to content

Commit 44ceaa8

Browse files
committed
itertools: Apply codereview
1 parent e1854ce commit 44ceaa8

File tree

2 files changed

+18
-20
lines changed

2 files changed

+18
-20
lines changed

vm/src/pyobject.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ where
670670
}
671671
}
672672

673-
#[derive(Clone)]
673+
#[derive(Clone, Debug)]
674674
pub struct PyCallable {
675675
obj: PyObjectRef,
676676
}

vm/src/stdlib/itertools.rs

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::cell::RefCell;
1+
use std::cell::{Cell, RefCell};
22
use std::cmp::Ordering;
33
use std::ops::{AddAssign, SubAssign};
44

@@ -12,7 +12,7 @@ use crate::obj::objint::{PyInt, PyIntRef};
1212
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
1313
use crate::obj::objtype;
1414
use crate::obj::objtype::PyClassRef;
15-
use crate::pyobject::{IdProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
15+
use crate::pyobject::{IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
1616
use crate::vm::VirtualMachine;
1717

1818
#[pyclass(name = "chain")]
@@ -293,9 +293,9 @@ impl PyItertoolsTakewhile {
293293
#[pyclass]
294294
#[derive(Debug)]
295295
struct PyItertoolsDropwhile {
296-
predicate: PyObjectRef,
296+
predicate: PyCallable,
297297
iterable: PyObjectRef,
298-
start_flag: RefCell<bool>,
298+
start_flag: Cell<bool>,
299299
}
300300

301301
impl PyValue for PyItertoolsDropwhile {
@@ -304,47 +304,45 @@ impl PyValue for PyItertoolsDropwhile {
304304
}
305305
}
306306

307+
type PyItertoolsDropwhileRef = PyRef<PyItertoolsDropwhile>;
308+
307309
#[pyimpl]
308310
impl PyItertoolsDropwhile {
309311
#[pymethod(name = "__new__")]
310312
#[allow(clippy::new_ret_no_self)]
311313
fn new(
312-
_cls: PyClassRef,
313-
predicate: PyObjectRef,
314+
cls: PyClassRef,
315+
predicate: PyCallable,
314316
iterable: PyObjectRef,
315317
vm: &VirtualMachine,
316-
) -> PyResult {
318+
) -> PyResult<PyItertoolsDropwhileRef> {
317319
let iter = get_iter(vm, &iterable)?;
318320

319-
Ok(PyItertoolsDropwhile {
321+
PyItertoolsDropwhile {
320322
predicate,
321323
iterable: iter,
322-
start_flag: RefCell::new(false),
324+
start_flag: Cell::new(false),
323325
}
324-
.into_ref(vm)
325-
.into_object())
326+
.into_ref_with_type(vm, cls)
326327
}
327328

328329
#[pymethod(name = "__next__")]
329330
fn next(&self, vm: &VirtualMachine) -> PyResult {
330331
let predicate = &self.predicate;
331332
let iterable = &self.iterable;
332333

333-
if !*self.start_flag.borrow_mut() {
334+
if !self.start_flag.get() {
334335
loop {
335336
let obj = call_next(vm, iterable)?;
336-
let pred_value = vm.invoke(predicate, vec![obj.clone()])?;
337+
let pred = predicate.clone();
338+
let pred_value = vm.invoke(&pred.into_object(), vec![obj.clone()])?;
337339
if !objbool::boolval(vm, pred_value)? {
338-
*self.start_flag.borrow_mut() = true;
340+
self.start_flag.set(true);
339341
return Ok(obj);
340342
}
341343
}
342344
}
343-
344-
loop {
345-
let obj = call_next(vm, iterable)?;
346-
return Ok(obj);
347-
}
345+
call_next(vm, iterable)
348346
}
349347

350348
#[pymethod(name = "__iter__")]

0 commit comments

Comments
 (0)