Skip to content

Commit cc25db9

Browse files
authored
Merge pull request RustPython#2170 from qingshi163/dev
Implement array setitem by slice
2 parents 446bf76 + 5a21bc2 commit cc25db9

File tree

3 files changed

+157
-21
lines changed

3 files changed

+157
-21
lines changed

Lib/test/test_array.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ class ArraySubclassWithKwargs(array.array):
2525
def __init__(self, typecode, newarg=None):
2626
array.array.__init__(self)
2727

28-
typecodes = 'ubBhHiIlLfdqQ'
28+
# TODO: RUSTPYTHON
29+
# We did not support typecode u for unicode yet
30+
# typecodes = 'ubBhHiIlLfdqQ'
31+
typecodes = 'bBhHiIlLfdqQ'
2932

3033
class MiscTest(unittest.TestCase):
3134

@@ -799,8 +802,6 @@ def test_extended_getslice(self):
799802
self.assertEqual(list(a[start:stop:step]),
800803
list(a)[start:stop:step])
801804

802-
# TODO: RUSTPYTHON
803-
@unittest.expectedFailure
804805
def test_setslice(self):
805806
a = array.array(self.typecode, self.example)
806807
a[:1] = a
@@ -1233,8 +1234,6 @@ def test_delslice(self):
12331234
a = array.array(self.typecode, range(10))
12341235
del a[9::1<<333]
12351236

1236-
# TODO: RUSTPYTHON
1237-
@unittest.expectedFailure
12381237
def test_assignment(self):
12391238
a = array.array(self.typecode, range(10))
12401239
a[::2] = array.array(self.typecode, [42]*5)

tests/snippets/stdlib_array.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,18 @@
1919
b = a
2020
assert a.__ne__(b) is False
2121
b = array("B", [3, 2, 1, 0])
22-
assert a.__ne__(b) is True
22+
assert a.__ne__(b) is True
23+
24+
# slice assignment step overflow behaviour test
25+
T = 'I'
26+
a = array(T, range(10))
27+
b = array(T, [100])
28+
a[::9999999999] = b
29+
assert a == array(T, [100, 1, 2, 3, 4, 5, 6, 7, 8, 9])
30+
a[::-9999999999] = b
31+
assert a == array(T, [100, 1, 2, 3, 4, 5, 6, 7, 8, 100])
32+
c = array(T)
33+
a[0:0:9999999999] = c
34+
assert a == array(T, [100, 1, 2, 3, 4, 5, 6, 7, 8, 100])
35+
a[0:0:-9999999999] = c
36+
assert a == array(T, [100, 1, 2, 3, 4, 5, 6, 7, 8, 100])

vm/src/stdlib/array.rs

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@ use crate::common::cell::{
44
};
55
use crate::function::OptionalArg;
66
use crate::obj::objbytes::PyBytesRef;
7-
use crate::obj::objsequence::PySliceableSequence;
7+
use crate::obj::objsequence::{get_slice_range, PySliceableSequence};
88
use crate::obj::objslice::PySliceRef;
99
use crate::obj::objstr::PyStringRef;
1010
use crate::obj::objtype::PyClassRef;
1111
use crate::obj::{objbool, objiter};
1212
use crate::pyobject::{
13-
BorrowValue, Either, IntoPyObject, PyArithmaticValue, PyClassImpl, PyComparisonValue,
14-
PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
13+
BorrowValue, Either, IdProtocol, IntoPyObject, PyArithmaticValue, PyClassImpl,
14+
PyComparisonValue, PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
15+
TypeProtocol,
1516
};
1617
use crate::VirtualMachine;
1718
use crossbeam_utils::atomic::AtomicCell;
1819
use itertools::Itertools;
20+
use num_bigint::BigInt;
21+
use num_traits::{One, Signed, ToPrimitive, Zero};
1922
use std::fmt;
2023

2124
struct ArrayTypeSpecifierError {
@@ -33,7 +36,7 @@ impl fmt::Display for ArrayTypeSpecifierError {
3336

3437
macro_rules! def_array_enum {
3538
($(($n:ident, $t:ident, $c:literal)),*$(,)?) => {
36-
#[derive(Debug)]
39+
#[derive(Debug, Clone)]
3740
enum ArrayContentType {
3841
$($n(Vec<$t>),)*
3942
}
@@ -223,19 +226,115 @@ macro_rules! def_array_enum {
223226
}
224227
}
225228

226-
fn setitem(&mut self, needle: Either<isize, PySliceRef>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
227-
match needle {
228-
Either::A(i) => {
229-
let i = self.idx(i, "array assignment", vm)?;
230-
match self {
231-
$(ArrayContentType::$n(v) => { v[i] = TryFromObject::try_from_object(vm, value)? },)*
229+
fn setitem_by_slice(&mut self, slice: PySliceRef, items: &ArrayContentType, vm: &VirtualMachine) -> PyResult<()> {
230+
let start = slice.start_index(vm)?;
231+
let stop = slice.stop_index(vm)?;
232+
let step = slice.step_index(vm)?.unwrap_or_else(BigInt::one);
233+
234+
if step.is_zero() {
235+
return Err(vm.new_value_error("slice step cannot be zero".to_owned()));
236+
}
237+
238+
match self {
239+
$(ArrayContentType::$n(elements) => if let ArrayContentType::$n(items) = items {
240+
if step == BigInt::one() {
241+
let range = get_slice_range(&start, &stop, elements.len());
242+
let range = if range.end < range.start {
243+
range.start..range.start
244+
} else {
245+
range
246+
};
247+
elements.splice(range, items.iter().cloned());
248+
return Ok(());
232249
}
233-
Ok(())
234-
}
235-
Either::B(_slice) => Err(vm.new_not_implemented_error("array slice is not implemented".to_owned())),
250+
251+
let (start, stop, step, is_negative_step) = if step.is_negative() {
252+
(
253+
stop.map(|x| if x == -BigInt::one() {
254+
elements.len() + BigInt::one()
255+
} else {
256+
x + 1
257+
}),
258+
start.map(|x| if x == -BigInt::one() {
259+
BigInt::from(elements.len())
260+
} else {
261+
x + 1
262+
}),
263+
-step,
264+
true
265+
)
266+
} else {
267+
(start, stop, step, false)
268+
};
269+
270+
let range = get_slice_range(&start, &stop, elements.len());
271+
let range = if range.end < range.start {
272+
range.start..range.start
273+
} else {
274+
range
275+
};
276+
277+
// step is not negative here
278+
if let Some(step) = step.to_usize() {
279+
let slicelen = if range.end > range.start {
280+
(range.end - range.start - 1) / step + 1
281+
} else {
282+
0
283+
};
284+
285+
if slicelen == items.len() {
286+
if is_negative_step {
287+
for (i, &item) in range.rev().step_by(step).zip(items) {
288+
elements[i] = item;
289+
}
290+
} else {
291+
for (i, &item) in range.step_by(step).zip(items) {
292+
elements[i] = item;
293+
}
294+
}
295+
Ok(())
296+
} else {
297+
Err(vm.new_value_error(format!(
298+
"attempt to assign sequence of size {} to extended slice of size {}",
299+
items.len(), slicelen
300+
)))
301+
}
302+
} else {
303+
// edge case, step is too big for usize
304+
// same behaviour as CPython
305+
let slicelen = if range.start < range.end { 1 } else { 0 };
306+
if match items.len() {
307+
0 => slicelen == 0,
308+
1 => {
309+
elements[
310+
if is_negative_step { range.end - 1 } else { range.start }
311+
] = items[0];
312+
true
313+
},
314+
_ => false,
315+
} {
316+
Ok(())
317+
} else {
318+
Err(vm.new_value_error(format!(
319+
"attempt to assign sequence of size {} to extended slice of size {}",
320+
items.len(), slicelen
321+
)))
322+
}
323+
}
324+
} else {
325+
Err(vm.new_type_error("bad argument type for built-in operation".to_owned()))
326+
},)*
236327
}
237328
}
238329

330+
fn setitem_by_idx(&mut self, i: isize, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
331+
let i = self.idx(i, "array assignment", vm)?;
332+
match self {
333+
$(ArrayContentType::$n(v) => { v[i] = TryFromObject::try_from_object(vm, value)? },)*
334+
}
335+
Ok(())
336+
}
337+
239338
fn repr(&self, _vm: &VirtualMachine) -> PyResult<String> {
240339
// we don't need ReprGuard here
241340
let s = match self {
@@ -449,12 +548,36 @@ impl PyArray {
449548

450549
#[pymethod(magic)]
451550
fn setitem(
452-
&self,
551+
zelf: PyRef<Self>,
453552
needle: Either<isize, PySliceRef>,
454553
obj: PyObjectRef,
455554
vm: &VirtualMachine,
456555
) -> PyResult<()> {
457-
self.borrow_value_mut().setitem(needle, obj, vm)
556+
match needle {
557+
Either::A(i) => zelf.borrow_value_mut().setitem_by_idx(i, obj, vm),
558+
Either::B(slice) => {
559+
let cloned;
560+
let guard;
561+
let items = if zelf.is(&obj) {
562+
cloned = zelf.borrow_value().clone();
563+
&cloned
564+
} else {
565+
match obj.payload::<PyArray>() {
566+
Some(array) => {
567+
guard = array.borrow_value();
568+
&*guard
569+
}
570+
None => {
571+
return Err(vm.new_type_error(format!(
572+
"can only assign array (not \"{}\") to array slice",
573+
obj.class().name
574+
)));
575+
}
576+
}
577+
};
578+
zelf.borrow_value_mut().setitem_by_slice(slice, items, vm)
579+
}
580+
}
458581
}
459582

460583
#[pymethod(name = "__repr__")]

0 commit comments

Comments
 (0)