Skip to content

Commit 3aa09dd

Browse files
Merge pull request RustPython#1646 from dralley/itertools
Add itertools.combinations_with_replacement()
2 parents 920ef52 + 4bbca2b commit 3aa09dd

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,25 @@ def assert_matches_seq(it, seq):
405405
with assert_raises(TypeError):
406406
itertools.combinations([1, 2, 3, 4], None)
407407

408+
# itertools.combinations
409+
it = itertools.combinations_with_replacement([1, 2, 3], 0)
410+
assert list(it) == [()]
411+
412+
it = itertools.combinations_with_replacement([1, 2, 3], 1)
413+
assert list(it) == [(1,), (2,), (3,)]
414+
415+
it = itertools.combinations_with_replacement([1, 2, 3], 2)
416+
assert list(it) == [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
417+
418+
it = itertools.combinations_with_replacement([1, 2], 3)
419+
assert list(it) == [(1, 1, 1), (1, 1, 2), (1, 2, 2), (2, 2, 2)]
420+
421+
with assert_raises(ValueError):
422+
itertools.combinations_with_replacement([1, 2, 3, 4], -2)
423+
424+
with assert_raises(TypeError):
425+
itertools.combinations_with_replacement([1, 2, 3, 4], None)
426+
408427
# itertools.permutations
409428
it = itertools.permutations([1, 2, 3])
410429
assert list(it) == [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]

vm/src/stdlib/itertools.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,100 @@ impl PyItertoolsCombinations {
10311031
}
10321032
}
10331033

1034+
#[pyclass]
1035+
#[derive(Debug)]
1036+
struct PyItertoolsCombinationsWithReplacement {
1037+
pool: Vec<PyObjectRef>,
1038+
indices: RefCell<Vec<usize>>,
1039+
r: Cell<usize>,
1040+
exhausted: Cell<bool>,
1041+
}
1042+
1043+
impl PyValue for PyItertoolsCombinationsWithReplacement {
1044+
fn class(vm: &VirtualMachine) -> PyClassRef {
1045+
vm.class("itertools", "combinations_with_replacement")
1046+
}
1047+
}
1048+
1049+
#[pyimpl]
1050+
impl PyItertoolsCombinationsWithReplacement {
1051+
#[pyslot(new)]
1052+
fn tp_new(
1053+
cls: PyClassRef,
1054+
iterable: PyObjectRef,
1055+
r: PyIntRef,
1056+
vm: &VirtualMachine,
1057+
) -> PyResult<PyRef<Self>> {
1058+
let iter = get_iter(vm, &iterable)?;
1059+
let pool = get_all(vm, &iter)?;
1060+
1061+
let r = r.as_bigint();
1062+
if r.is_negative() {
1063+
return Err(vm.new_value_error("r must be non-negative".to_string()));
1064+
}
1065+
let r = r.to_usize().unwrap();
1066+
1067+
let n = pool.len();
1068+
1069+
PyItertoolsCombinationsWithReplacement {
1070+
pool,
1071+
indices: RefCell::new(vec![0; r]),
1072+
r: Cell::new(r),
1073+
exhausted: Cell::new(n == 0 && r > 0),
1074+
}
1075+
.into_ref_with_type(vm, cls)
1076+
}
1077+
1078+
#[pymethod(name = "__iter__")]
1079+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
1080+
zelf
1081+
}
1082+
1083+
#[pymethod(name = "__next__")]
1084+
fn next(&self, vm: &VirtualMachine) -> PyResult {
1085+
// stop signal
1086+
if self.exhausted.get() {
1087+
return Err(new_stop_iteration(vm));
1088+
}
1089+
1090+
let n = self.pool.len();
1091+
let r = self.r.get();
1092+
1093+
if r == 0 {
1094+
self.exhausted.set(true);
1095+
return Ok(vm.ctx.new_tuple(vec![]));
1096+
}
1097+
1098+
let mut indices = self.indices.borrow_mut();
1099+
1100+
let res = vm
1101+
.ctx
1102+
.new_tuple(indices.iter().map(|&i| self.pool[i].clone()).collect());
1103+
1104+
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1105+
let mut idx = r as isize - 1;
1106+
while idx >= 0 && indices[idx as usize] == n - 1 {
1107+
idx -= 1;
1108+
}
1109+
1110+
// If no suitable index is found, then the indices are all at
1111+
// their maximum value and we're done.
1112+
if idx < 0 {
1113+
self.exhausted.set(true);
1114+
} else {
1115+
let index = indices[idx as usize] + 1;
1116+
1117+
// Increment the current index which we know is not at its
1118+
// maximum. Then set all to the right to the same value.
1119+
for j in idx as usize..r {
1120+
indices[j as usize] = index as usize;
1121+
}
1122+
}
1123+
1124+
Ok(res)
1125+
}
1126+
}
1127+
10341128
#[pyclass]
10351129
#[derive(Debug)]
10361130
struct PyItertoolsPermutations {
@@ -1257,6 +1351,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
12571351
let combinations = ctx.new_class("combinations", ctx.object());
12581352
PyItertoolsCombinations::extend_class(ctx, &combinations);
12591353

1354+
let combinations_with_replacement =
1355+
ctx.new_class("combinations_with_replacement", ctx.object());
1356+
PyItertoolsCombinationsWithReplacement::extend_class(ctx, &combinations_with_replacement);
1357+
12601358
let count = ctx.new_class("count", ctx.object());
12611359
PyItertoolsCount::extend_class(ctx, &count);
12621360

@@ -1296,6 +1394,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
12961394
"chain" => chain,
12971395
"compress" => compress,
12981396
"combinations" => combinations,
1397+
"combinations_with_replacement" => combinations_with_replacement,
12991398
"count" => count,
13001399
"cycle" => cycle,
13011400
"dropwhile" => dropwhile,

0 commit comments

Comments
 (0)