Skip to content

Commit 5608808

Browse files
seryoungshim17youknowone
authored andcommitted
Add itertool.combinations.__reduce__ method
Add result in struct PyItertoolsCombinations
1 parent cac9918 commit 5608808

File tree

1 file changed

+74
-27
lines changed

1 file changed

+74
-27
lines changed

vm/src/stdlib/itertools.rs

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,7 @@ mod decl {
13081308
struct PyItertoolsCombinations {
13091309
pool: Vec<PyObjectRef>,
13101310
indices: PyRwLock<Vec<usize>>,
1311+
result: PyRwLock<Option<Vec<usize>>>,
13111312
r: AtomicCell<usize>,
13121313
exhausted: AtomicCell<bool>,
13131314
}
@@ -1341,6 +1342,7 @@ mod decl {
13411342
PyItertoolsCombinations {
13421343
pool,
13431344
indices: PyRwLock::new((0..r).collect()),
1345+
result: PyRwLock::new(None),
13441346
r: AtomicCell::new(r),
13451347
exhausted: AtomicCell::new(r > n),
13461348
}
@@ -1350,7 +1352,39 @@ mod decl {
13501352
}
13511353

13521354
#[pyclass(with(IterNext, Constructor))]
1353-
impl PyItertoolsCombinations {}
1355+
impl PyItertoolsCombinations {
1356+
#[pymethod(magic)]
1357+
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
1358+
let result = zelf.result.read();
1359+
if let Some(result) = &*result {
1360+
if zelf.exhausted.load() {
1361+
vm.new_tuple((
1362+
zelf.class().to_owned(),
1363+
vm.new_tuple((vm.new_tuple(()), vm.ctx.new_int(zelf.r.load()))),
1364+
))
1365+
} else {
1366+
vm.new_tuple((
1367+
zelf.class().to_owned(),
1368+
vm.new_tuple((
1369+
vm.new_tuple(zelf.pool.clone()),
1370+
vm.ctx.new_int(zelf.r.load()),
1371+
)),
1372+
vm.ctx
1373+
.new_tuple(result.iter().map(|&i| zelf.pool[i].clone()).collect()),
1374+
))
1375+
}
1376+
} else {
1377+
vm.new_tuple((
1378+
zelf.class().to_owned(),
1379+
vm.new_tuple((
1380+
vm.new_tuple(zelf.pool.clone()),
1381+
vm.ctx.new_int(zelf.r.load()),
1382+
)),
1383+
))
1384+
}
1385+
}
1386+
}
1387+
13541388
impl IterNextIterable for PyItertoolsCombinations {}
13551389
impl IterNext for PyItertoolsCombinations {
13561390
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
@@ -1367,38 +1401,51 @@ mod decl {
13671401
return Ok(PyIterReturn::Return(vm.new_tuple(()).into()));
13681402
}
13691403

1370-
let res = vm.ctx.new_tuple(
1371-
zelf.indices
1372-
.read()
1373-
.iter()
1374-
.map(|&i| zelf.pool[i].clone())
1375-
.collect(),
1376-
);
1404+
let mut result = zelf.result.write();
13771405

1378-
let mut indices = zelf.indices.write();
1406+
if let Some(ref mut result) = *result {
1407+
let mut indices = zelf.indices.write();
13791408

1380-
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1381-
let mut idx = r as isize - 1;
1382-
while idx >= 0 && indices[idx as usize] == idx as usize + n - r {
1383-
idx -= 1;
1384-
}
1409+
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1410+
let mut idx = r as isize - 1;
1411+
while idx >= 0 && indices[idx as usize] == idx as usize + n - r {
1412+
idx -= 1;
1413+
}
13851414

1386-
// If no suitable index is found, then the indices are all at
1387-
// their maximum value and we're done.
1388-
if idx < 0 {
1389-
zelf.exhausted.store(true);
1390-
} else {
1391-
// Increment the current index which we know is not at its
1392-
// maximum. Then move back to the right setting each index
1393-
// to its lowest possible value (one higher than the index
1394-
// to its left -- this maintains the sort order invariant).
1395-
indices[idx as usize] += 1;
1396-
for j in idx as usize + 1..r {
1397-
indices[j] = indices[j - 1] + 1;
1415+
// If no suitable index is found, then the indices are all at
1416+
// their maximum value and we're done.
1417+
if idx < 0 {
1418+
zelf.exhausted.store(true);
1419+
return Ok(PyIterReturn::StopIteration(None));
1420+
} else {
1421+
// Increment the current index which we know is not at its
1422+
// maximum. Then move back to the right setting each index
1423+
// to its lowest possible value (one higher than the index
1424+
// to its left -- this maintains the sort order invariant).
1425+
indices[idx as usize] += 1;
1426+
for j in idx as usize + 1..r {
1427+
indices[j] = indices[j - 1] + 1;
1428+
}
1429+
for j in 0..r {
1430+
result[j] = indices[j];
1431+
}
13981432
}
1433+
} else {
1434+
*result = Some((0..r).collect());
13991435
}
14001436

1401-
Ok(PyIterReturn::Return(res.into()))
1437+
Ok(PyIterReturn::Return(
1438+
vm.ctx
1439+
.new_tuple(
1440+
result
1441+
.as_ref()
1442+
.unwrap()
1443+
.iter()
1444+
.map(|&i| zelf.pool[i].clone())
1445+
.collect(),
1446+
)
1447+
.into(),
1448+
))
14021449
}
14031450
}
14041451

0 commit comments

Comments
 (0)