Skip to content

Commit ec011f5

Browse files
authored
Merge pull request RustPython#1604 from Space0726/itertools
Implement itertools.zip_longest
2 parents c4466a0 + fbd727e commit ec011f5

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,46 @@ def assert_matches_seq(it, seq):
344344

345345
with assert_raises(ValueError):
346346
itertools.combinations([1, 2, 3, 4], -2)
347+
348+
# itertools.zip_longest tests
349+
zl = itertools.zip_longest
350+
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \
351+
== [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]
352+
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) \
353+
== [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), (None, None, 99)]
354+
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99], fillvalue='d')) \
355+
== [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), ('d', 'd', 99)]
356+
357+
assert list(zl(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)]
358+
assert list(zl()) == []
359+
360+
assert list(zl(*zl(['a', 'b', 'c'], range(1, 4)))) \
361+
== [('a', 'b', 'c'), (1, 2, 3)]
362+
assert list(zl(*zl(['a', 'b', 'c'], range(1, 5)))) \
363+
== [('a', 'b', 'c', None), (1, 2, 3, 4)]
364+
assert list(zl(*zl(['a', 'b', 'c'], range(1, 5), fillvalue=100))) \
365+
== [('a', 'b', 'c', 100), (1, 2, 3, 4)]
366+
367+
368+
# test infinite iterator
369+
class Counter(object):
370+
def __init__(self, counter=0):
371+
self.counter = counter
372+
373+
def __next__(self):
374+
self.counter += 1
375+
return self.counter
376+
377+
def __iter__(self):
378+
return self
379+
380+
381+
it = zl(Counter(), Counter(3))
382+
assert next(it) == (1, 4)
383+
assert next(it) == (2, 5)
384+
385+
it = zl([1,2], [3])
386+
assert next(it) == (1, 3)
387+
assert next(it) == (2, None)
388+
with assert_raises(StopIteration):
389+
next(it)

vm/src/stdlib/itertools.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,89 @@ impl PyItertoolsCombinations {
953953
}
954954
}
955955

956+
#[pyclass]
957+
#[derive(Debug)]
958+
struct PyItertoolsZiplongest {
959+
iterators: Vec<PyObjectRef>,
960+
fillvalue: PyObjectRef,
961+
numactive: Cell<usize>,
962+
}
963+
964+
impl PyValue for PyItertoolsZiplongest {
965+
fn class(vm: &VirtualMachine) -> PyClassRef {
966+
vm.class("itertools", "zip_longest")
967+
}
968+
}
969+
970+
#[derive(FromArgs)]
971+
struct ZiplongestArgs {
972+
#[pyarg(keyword_only, optional = true)]
973+
fillvalue: OptionalArg<PyObjectRef>,
974+
}
975+
976+
#[pyimpl]
977+
impl PyItertoolsZiplongest {
978+
#[pyslot(new)]
979+
fn tp_new(
980+
cls: PyClassRef,
981+
iterables: Args,
982+
args: ZiplongestArgs,
983+
vm: &VirtualMachine,
984+
) -> PyResult<PyRef<Self>> {
985+
let fillvalue = match args.fillvalue.into_option() {
986+
Some(i) => i,
987+
None => vm.get_none(),
988+
};
989+
990+
let iterators = iterables
991+
.into_iter()
992+
.map(|iterable| get_iter(vm, &iterable))
993+
.collect::<Result<Vec<_>, _>>()?;
994+
995+
let numactive = Cell::new(iterators.len());
996+
997+
PyItertoolsZiplongest {
998+
iterators,
999+
fillvalue,
1000+
numactive,
1001+
}
1002+
.into_ref_with_type(vm, cls)
1003+
}
1004+
1005+
#[pymethod(name = "__next__")]
1006+
fn next(&self, vm: &VirtualMachine) -> PyResult {
1007+
if self.iterators.is_empty() {
1008+
Err(new_stop_iteration(vm))
1009+
} else {
1010+
let mut result: Vec<PyObjectRef> = Vec::new();
1011+
let mut numactive = self.numactive.get();
1012+
1013+
for idx in 0..self.iterators.len() {
1014+
let next_obj = match call_next(vm, &self.iterators[idx]) {
1015+
Ok(obj) => obj,
1016+
Err(err) => {
1017+
if !objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) {
1018+
return Err(err);
1019+
}
1020+
numactive -= 1;
1021+
if numactive == 0 {
1022+
return Err(new_stop_iteration(vm));
1023+
}
1024+
self.fillvalue.clone()
1025+
}
1026+
};
1027+
result.push(next_obj);
1028+
}
1029+
Ok(vm.ctx.new_tuple(result))
1030+
}
1031+
}
1032+
1033+
#[pymethod(name = "__iter__")]
1034+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
1035+
zelf
1036+
}
1037+
}
1038+
9561039
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
9571040
let ctx = &vm.ctx;
9581041

@@ -991,6 +1074,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
9911074
let tee = ctx.new_class("tee", ctx.object());
9921075
PyItertoolsTee::extend_class(ctx, &tee);
9931076

1077+
let zip_longest = ctx.new_class("zip_longest", ctx.object());
1078+
PyItertoolsZiplongest::extend_class(ctx, &zip_longest);
1079+
9941080
py_module!(vm, "itertools", {
9951081
"accumulate" => accumulate,
9961082
"chain" => chain,
@@ -1005,5 +1091,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
10051091
"takewhile" => takewhile,
10061092
"tee" => tee,
10071093
"product" => product,
1094+
"zip_longest" => zip_longest,
10081095
})
10091096
}

0 commit comments

Comments
 (0)