Skip to content

Commit fdfdfbc

Browse files
authored
Merge pull request RustPython#2473 from rickygao/itertools-support
Follow the latest `itertools` module
2 parents 9439212 + 52242d4 commit fdfdfbc

File tree

2 files changed

+152
-35
lines changed

2 files changed

+152
-35
lines changed

Lib/test/test_itertools.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import sys
1313
import struct
1414
import threading
15+
import gc
16+
1517
maxsize = support.MAX_Py_ssize_t
1618
minsize = -maxsize-1
1719

@@ -193,7 +195,6 @@ def test_chain_reducible(self):
193195
self.assertRaises(TypeError, list, oper(chain(2, 3)))
194196
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
195197
self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef'))
196-
197198
# TODO: RUSTPYTHON
198199
@unittest.expectedFailure
199200
def test_chain_setstate(self):
@@ -208,7 +209,6 @@ def test_chain_setstate(self):
208209
it = chain()
209210
it.__setstate__((iter(['abc', 'def']), iter(['ghi'])))
210211
self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f'])
211-
212212
# TODO: RUSTPYTHON
213213
@unittest.expectedFailure
214214
def test_combinations(self):
@@ -1053,6 +1053,25 @@ def run(r1, r2):
10531053
self.assertEqual(next(it), (1, 2))
10541054
self.assertRaises(RuntimeError, next, it)
10551055

1056+
def test_pairwise(self):
1057+
self.assertEqual(list(pairwise('')), [])
1058+
self.assertEqual(list(pairwise('a')), [])
1059+
self.assertEqual(list(pairwise('ab')),
1060+
[('a', 'b')]),
1061+
self.assertEqual(list(pairwise('abcde')),
1062+
[('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'e')])
1063+
self.assertEqual(list(pairwise(range(10_000))),
1064+
list(zip(range(10_000), range(1, 10_000))))
1065+
1066+
with self.assertRaises(TypeError):
1067+
pairwise() # too few arguments
1068+
with self.assertRaises(TypeError):
1069+
pairwise('abc', 10) # too many arguments
1070+
with self.assertRaises(TypeError):
1071+
pairwise(iterable='abc') # keyword arguments
1072+
with self.assertRaises(TypeError):
1073+
pairwise(None) # non-iterable argument
1074+
10561075
def test_product(self):
10571076
for args, result in [
10581077
([], [()]), # zero iterables
@@ -1609,6 +1628,51 @@ def test_StopIteration(self):
16091628
self.assertRaises(StopIteration, next, f(lambda x:x, []))
16101629
self.assertRaises(StopIteration, next, f(lambda x:x, StopNow()))
16111630

1631+
@support.cpython_only
1632+
def test_combinations_result_gc(self):
1633+
# bpo-42536: combinations's tuple-reuse speed trick breaks the GC's
1634+
# assumptions about what can be untracked. Make sure we re-track result
1635+
# tuples whenever we reuse them.
1636+
it = combinations([None, []], 1)
1637+
next(it)
1638+
gc.collect()
1639+
# That GC collection probably untracked the recycled internal result
1640+
# tuple, which has the value (None,). Make sure it's re-tracked when
1641+
# it's mutated and returned from __next__:
1642+
self.assertTrue(gc.is_tracked(next(it)))
1643+
1644+
@support.cpython_only
1645+
def test_combinations_with_replacement_result_gc(self):
1646+
# Ditto for combinations_with_replacement.
1647+
it = combinations_with_replacement([None, []], 1)
1648+
next(it)
1649+
gc.collect()
1650+
self.assertTrue(gc.is_tracked(next(it)))
1651+
1652+
@support.cpython_only
1653+
def test_permutations_result_gc(self):
1654+
# Ditto for permutations.
1655+
it = permutations([None, []], 1)
1656+
next(it)
1657+
gc.collect()
1658+
self.assertTrue(gc.is_tracked(next(it)))
1659+
1660+
@support.cpython_only
1661+
def test_product_result_gc(self):
1662+
# Ditto for product.
1663+
it = product([None, []])
1664+
next(it)
1665+
gc.collect()
1666+
self.assertTrue(gc.is_tracked(next(it)))
1667+
1668+
@support.cpython_only
1669+
def test_zip_longest_result_gc(self):
1670+
# Ditto for zip_longest.
1671+
it = zip_longest([[]])
1672+
gc.collect()
1673+
self.assertTrue(gc.is_tracked(next(it)))
1674+
1675+
16121676
class TestExamples(unittest.TestCase):
16131677

16141678
def test_accumulate(self):
@@ -1848,6 +1912,10 @@ def test_islice(self):
18481912
a = []
18491913
self.makecycle(islice([a]*2, None), a)
18501914

1915+
def test_pairwise(self):
1916+
a = []
1917+
self.makecycle(pairwise([a]*5), a)
1918+
18511919
def test_permutations(self):
18521920
a = []
18531921
self.makecycle(permutations([1,2,a,3], 3), a)
@@ -1946,6 +2014,7 @@ def L(seqn):
19462014

19472015

19482016
class TestVariousIteratorArgs(unittest.TestCase):
2017+
19492018
def test_accumulate(self):
19502019
s = [1,2,3,4,5]
19512020
r = [1,3,6,10,15]
@@ -2055,6 +2124,17 @@ def test_islice(self):
20552124
self.assertRaises(TypeError, islice, N(s), 10)
20562125
self.assertRaises(ZeroDivisionError, list, islice(E(s), 10))
20572126

2127+
def test_pairwise(self):
2128+
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
2129+
for g in (G, I, Ig, S, L, R):
2130+
seq = list(g(s))
2131+
expected = list(zip(seq, seq[1:]))
2132+
actual = list(pairwise(g(s)))
2133+
self.assertEqual(actual, expected)
2134+
self.assertRaises(TypeError, pairwise, X(s))
2135+
self.assertRaises(TypeError, pairwise, N(s))
2136+
self.assertRaises(ZeroDivisionError, list, pairwise(E(s)))
2137+
20582138
def test_starmap(self):
20592139
for s in (range(10), range(0), range(100), (7,11), range(20,50,5)):
20602140
for g in (G, I, Ig, S, L, R):
@@ -2356,7 +2436,7 @@ def test_permutations_sizeof(self):
23562436
... "Count how many times the predicate is true"
23572437
... return sum(map(pred, iterable))
23582438
2359-
>>> def padnone(iterable):
2439+
>>> def pad_none(iterable):
23602440
... "Returns the sequence elements and then returns None indefinitely"
23612441
... return chain(iterable, repeat(None))
23622442
@@ -2378,15 +2458,6 @@ def test_permutations_sizeof(self):
23782458
... else:
23792459
... return starmap(func, repeat(args, times))
23802460
2381-
>>> def pairwise(iterable):
2382-
... "s -> (s0,s1), (s1,s2), (s2, s3), ..."
2383-
... a, b = tee(iterable)
2384-
... try:
2385-
... next(b)
2386-
... except StopIteration:
2387-
... pass
2388-
... return zip(a, b)
2389-
23902461
>>> def grouper(n, iterable, fillvalue=None):
23912462
... "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
23922463
... args = [iter(iterable)] * n
@@ -2517,16 +2588,7 @@ def test_permutations_sizeof(self):
25172588
>>> take(5, map(int, repeatfunc(random.random)))
25182589
[0, 0, 0, 0, 0]
25192590
2520-
>>> list(pairwise('abcd'))
2521-
[('a', 'b'), ('b', 'c'), ('c', 'd')]
2522-
2523-
>>> list(pairwise([]))
2524-
[]
2525-
2526-
>>> list(pairwise('a'))
2527-
[]
2528-
2529-
>>> list(islice(padnone('abc'), 0, 6))
2591+
>>> list(islice(pad_none('abc'), 0, 6))
25302592
['a', 'b', 'c', None, None, None]
25312593
25322594
>>> list(ncycles('abc', 3))

vm/src/stdlib/itertools.rs

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -804,10 +804,20 @@ mod decl {
804804
#[derive(Debug)]
805805
struct PyItertoolsAccumulate {
806806
iterable: PyObjectRef,
807-
binop: PyObjectRef,
807+
binop: Option<PyObjectRef>,
808+
initial: Option<PyObjectRef>,
808809
acc_value: PyRwLock<Option<PyObjectRef>>,
809810
}
810811

812+
#[derive(FromArgs)]
813+
struct AccumulateArgs {
814+
iterable: PyObjectRef,
815+
#[pyarg(any, optional)]
816+
func: OptionalOption<PyObjectRef>,
817+
#[pyarg(named, optional)]
818+
initial: OptionalOption<PyObjectRef>,
819+
}
820+
811821
impl PyValue for PyItertoolsAccumulate {
812822
fn class(_vm: &VirtualMachine) -> &PyTypeRef {
813823
Self::static_type()
@@ -819,15 +829,15 @@ mod decl {
819829
#[pyslot]
820830
fn tp_new(
821831
cls: PyTypeRef,
822-
iterable: PyObjectRef,
823-
binop: OptionalArg<PyObjectRef>,
832+
args: AccumulateArgs,
824833
vm: &VirtualMachine,
825834
) -> PyResult<PyRef<Self>> {
826-
let iter = get_iter(vm, iterable)?;
835+
let iter = get_iter(vm, args.iterable)?;
827836

828837
PyItertoolsAccumulate {
829838
iterable: iter,
830-
binop: binop.unwrap_or_none(vm),
839+
binop: args.func.flatten(),
840+
initial: args.initial.flatten(),
831841
acc_value: PyRwLock::new(None),
832842
}
833843
.into_ref_with_type(vm, cls)
@@ -836,17 +846,19 @@ mod decl {
836846
impl PyIter for PyItertoolsAccumulate {
837847
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
838848
let iterable = &zelf.iterable;
839-
let obj = call_next(vm, iterable)?;
840849

841850
let acc_value = zelf.acc_value.read().clone();
842851

843852
let next_acc_value = match acc_value {
844-
None => obj,
853+
None => match &zelf.initial {
854+
None => call_next(vm, iterable)?,
855+
Some(obj) => obj.clone(),
856+
},
845857
Some(value) => {
846-
if vm.is_none(&zelf.binop) {
847-
vm._add(&value, &obj)?
848-
} else {
849-
vm.invoke(&zelf.binop, vec![value, obj])?
858+
let obj = call_next(vm, iterable)?;
859+
match &zelf.binop {
860+
None => vm._add(&value, &obj)?,
861+
Some(op) => vm.invoke(op, vec![value, obj])?,
850862
}
851863
}
852864
};
@@ -1387,7 +1399,7 @@ mod decl {
13871399
}
13881400

13891401
#[derive(FromArgs)]
1390-
struct ZiplongestArgs {
1402+
struct ZipLongestArgs {
13911403
#[pyarg(named, optional)]
13921404
fillvalue: OptionalArg<PyObjectRef>,
13931405
}
@@ -1398,7 +1410,7 @@ mod decl {
13981410
fn tp_new(
13991411
cls: PyTypeRef,
14001412
iterables: Args,
1401-
args: ZiplongestArgs,
1413+
args: ZipLongestArgs,
14021414
vm: &VirtualMachine,
14031415
) -> PyResult<PyRef<Self>> {
14041416
let fillvalue = args.fillvalue.unwrap_or_none(vm);
@@ -1442,4 +1454,47 @@ mod decl {
14421454
}
14431455
}
14441456
}
1457+
1458+
#[pyattr]
1459+
#[pyclass(name = "pairwise")]
1460+
#[derive(Debug)]
1461+
struct PyItertoolsPairwise {
1462+
iterator: PyObjectRef,
1463+
old: PyRwLock<Option<PyObjectRef>>,
1464+
}
1465+
1466+
impl PyValue for PyItertoolsPairwise {
1467+
fn class(_vm: &VirtualMachine) -> &PyTypeRef {
1468+
Self::static_type()
1469+
}
1470+
}
1471+
1472+
#[pyimpl(with(PyIter))]
1473+
impl PyItertoolsPairwise {
1474+
#[pyslot]
1475+
fn tp_new(
1476+
cls: PyTypeRef,
1477+
iterable: PyObjectRef,
1478+
vm: &VirtualMachine,
1479+
) -> PyResult<PyRef<Self>> {
1480+
let iterator = get_iter(vm, iterable)?;
1481+
1482+
PyItertoolsPairwise {
1483+
iterator,
1484+
old: PyRwLock::new(None),
1485+
}
1486+
.into_ref_with_type(vm, cls)
1487+
}
1488+
}
1489+
impl PyIter for PyItertoolsPairwise {
1490+
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
1491+
let old = match zelf.old.read().clone() {
1492+
None => call_next(vm, &zelf.iterator)?,
1493+
Some(obj) => obj,
1494+
};
1495+
let new = call_next(vm, &zelf.iterator)?;
1496+
*zelf.old.write() = Some(new.clone());
1497+
Ok(vm.ctx.new_tuple(vec![old, new]))
1498+
}
1499+
}
14451500
}

0 commit comments

Comments
 (0)