@@ -1031,6 +1031,100 @@ impl PyItertoolsCombinations {
1031
1031
}
1032
1032
}
1033
1033
1034
+ #[ pyclass]
1035
+ #[ derive( Debug ) ]
1036
+ struct PyItertoolsCombinationsWithReplacement {
1037
+ pool : Vec < PyObjectRef > ,
1038
+ indices : RefCell < Vec < usize > > ,
1039
+ r : Cell < usize > ,
1040
+ exhausted : Cell < bool > ,
1041
+ }
1042
+
1043
+ impl PyValue for PyItertoolsCombinationsWithReplacement {
1044
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
1045
+ vm. class ( "itertools" , "combinations_with_replacement" )
1046
+ }
1047
+ }
1048
+
1049
+ #[ pyimpl]
1050
+ impl PyItertoolsCombinationsWithReplacement {
1051
+ #[ pyslot( new) ]
1052
+ fn tp_new (
1053
+ cls : PyClassRef ,
1054
+ iterable : PyObjectRef ,
1055
+ r : PyIntRef ,
1056
+ vm : & VirtualMachine ,
1057
+ ) -> PyResult < PyRef < Self > > {
1058
+ let iter = get_iter ( vm, & iterable) ?;
1059
+ let pool = get_all ( vm, & iter) ?;
1060
+
1061
+ let r = r. as_bigint ( ) ;
1062
+ if r. is_negative ( ) {
1063
+ return Err ( vm. new_value_error ( "r must be non-negative" . to_string ( ) ) ) ;
1064
+ }
1065
+ let r = r. to_usize ( ) . unwrap ( ) ;
1066
+
1067
+ let n = pool. len ( ) ;
1068
+
1069
+ PyItertoolsCombinationsWithReplacement {
1070
+ pool,
1071
+ indices : RefCell :: new ( vec ! [ 0 ; r] ) ,
1072
+ r : Cell :: new ( r) ,
1073
+ exhausted : Cell :: new ( n == 0 && r > 0 ) ,
1074
+ }
1075
+ . into_ref_with_type ( vm, cls)
1076
+ }
1077
+
1078
+ #[ pymethod( name = "__iter__" ) ]
1079
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
1080
+ zelf
1081
+ }
1082
+
1083
+ #[ pymethod( name = "__next__" ) ]
1084
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
1085
+ // stop signal
1086
+ if self . exhausted . get ( ) {
1087
+ return Err ( new_stop_iteration ( vm) ) ;
1088
+ }
1089
+
1090
+ let n = self . pool . len ( ) ;
1091
+ let r = self . r . get ( ) ;
1092
+
1093
+ if r == 0 {
1094
+ self . exhausted . set ( true ) ;
1095
+ return Ok ( vm. ctx . new_tuple ( vec ! [ ] ) ) ;
1096
+ }
1097
+
1098
+ let mut indices = self . indices . borrow_mut ( ) ;
1099
+
1100
+ let res = vm
1101
+ . ctx
1102
+ . new_tuple ( indices. iter ( ) . map ( |& i| self . pool [ i] . clone ( ) ) . collect ( ) ) ;
1103
+
1104
+ // Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1105
+ let mut idx = r as isize - 1 ;
1106
+ while idx >= 0 && indices[ idx as usize ] == n - 1 {
1107
+ idx -= 1 ;
1108
+ }
1109
+
1110
+ // If no suitable index is found, then the indices are all at
1111
+ // their maximum value and we're done.
1112
+ if idx < 0 {
1113
+ self . exhausted . set ( true ) ;
1114
+ } else {
1115
+ let index = indices[ idx as usize ] + 1 ;
1116
+
1117
+ // Increment the current index which we know is not at its
1118
+ // maximum. Then set all to the right to the same value.
1119
+ for j in idx as usize ..r {
1120
+ indices[ j as usize ] = index as usize ;
1121
+ }
1122
+ }
1123
+
1124
+ Ok ( res)
1125
+ }
1126
+ }
1127
+
1034
1128
#[ pyclass]
1035
1129
#[ derive( Debug ) ]
1036
1130
struct PyItertoolsPermutations {
@@ -1257,6 +1351,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
1257
1351
let combinations = ctx. new_class ( "combinations" , ctx. object ( ) ) ;
1258
1352
PyItertoolsCombinations :: extend_class ( ctx, & combinations) ;
1259
1353
1354
+ let combinations_with_replacement =
1355
+ ctx. new_class ( "combinations_with_replacement" , ctx. object ( ) ) ;
1356
+ PyItertoolsCombinationsWithReplacement :: extend_class ( ctx, & combinations_with_replacement) ;
1357
+
1260
1358
let count = ctx. new_class ( "count" , ctx. object ( ) ) ;
1261
1359
PyItertoolsCount :: extend_class ( ctx, & count) ;
1262
1360
@@ -1296,6 +1394,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
1296
1394
"chain" => chain,
1297
1395
"compress" => compress,
1298
1396
"combinations" => combinations,
1397
+ "combinations_with_replacement" => combinations_with_replacement,
1299
1398
"count" => count,
1300
1399
"cycle" => cycle,
1301
1400
"dropwhile" => dropwhile,
0 commit comments