Skip to content

Commit fa3fd72

Browse files
committed
Fix array extend and improve
1 parent e6230f3 commit fa3fd72

File tree

2 files changed

+66
-39
lines changed

2 files changed

+66
-39
lines changed

Lib/test/test_array.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,6 @@ def test_reverse(self):
957957
array.array(self.typecode, self.example[::-1])
958958
)
959959

960-
@unittest.skip("TODO: RUSTPYTHON")
961960
def test_extend(self):
962961
a = array.array(self.typecode, self.example)
963962
self.assertRaises(TypeError, a.extend)

vm/src/stdlib/array.rs

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -283,54 +283,58 @@ macro_rules! def_array_enum {
283283
}
284284
}
285285

286-
fn add(&self, other: &ArrayContentType, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
286+
fn add(&self, other: &ArrayContentType, vm: &VirtualMachine) -> PyResult<Self> {
287287
match self {
288288
$(ArrayContentType::$n(v) => if let ArrayContentType::$n(other) = other {
289289
let elements = v.iter().chain(other.iter()).cloned().collect();
290-
let sliced = ArrayContentType::$n(elements);
291-
let obj = PyArray {
292-
array: PyRwLock::new(sliced)
293-
}
294-
.into_object(vm);
295-
Ok(obj)
290+
Ok(ArrayContentType::$n(elements))
296291
} else {
297292
Err(vm.new_type_error("bad argument type for built-in operation".to_owned()))
298293
},)*
299294
}
300295
}
301296

302-
fn iadd(&mut self, other: ArrayContentType, vm: &VirtualMachine) -> PyResult<()> {
297+
fn iadd(&mut self, other: &ArrayContentType, vm: &VirtualMachine) -> PyResult<()> {
303298
match self {
304-
$(ArrayContentType::$n(v) => if let ArrayContentType::$n(mut other) = other {
305-
v.append(&mut other);
299+
$(ArrayContentType::$n(v) => if let ArrayContentType::$n(other) = other {
300+
v.extend(other);
306301
Ok(())
307302
} else {
308303
Err(vm.new_type_error("can only extend with array of same kind".to_owned()))
309304
},)*
310305
}
311306
}
312307

313-
fn mul(&self, counter: isize, vm: &VirtualMachine) -> PyObjectRef {
308+
fn mul(&self, counter: isize) -> Self {
314309
let counter = if counter < 0 { 0 } else { counter as usize };
315310
match self {
316311
$(ArrayContentType::$n(v) => {
317-
let elements = v.iter().cycle().take(v.len() * counter).cloned().collect();
318-
let sliced = ArrayContentType::$n(elements);
319-
PyArray {
320-
array: PyRwLock::new(sliced)
321-
}
322-
.into_object(vm)
312+
let elements = v.repeat(counter);
313+
ArrayContentType::$n(elements)
323314
})*
324315
}
325316
}
326317

327-
fn imul(&mut self, counter: isize) {
328-
let counter = if counter < 0 { 0 } else { counter as usize };
318+
fn clear(&mut self) {
329319
match self {
330-
$(ArrayContentType::$n(v) => {
331-
let mut elements = v.iter().cycle().take(v.len() * counter).cloned().collect();
332-
std::mem::swap(v, &mut elements);
333-
})*
320+
$(ArrayContentType::$n(v) => v.clear(),)*
321+
}
322+
}
323+
324+
fn imul(&mut self, counter: isize) {
325+
if counter <= 0 {
326+
self.clear();
327+
} else if counter != 1 {
328+
let counter = counter as usize;
329+
match self {
330+
$(ArrayContentType::$n(v) => {
331+
let old = v.clone();
332+
v.reserve((counter - 1) * old.len());
333+
for _ in 1..counter {
334+
v.extend(&old);
335+
}
336+
})*
337+
}
334338
}
335339
}
336340

@@ -489,7 +493,7 @@ impl PyArray {
489493
array: PyRwLock::new(array),
490494
};
491495
if let OptionalArg::Present(init) = init {
492-
zelf.extend(init, vm)?;
496+
zelf.extend_from_iterable(init, vm)?;
493497
}
494498
zelf.into_ref_with_type(vm, cls)
495499
}
@@ -525,15 +529,27 @@ impl PyArray {
525529
self.borrow_value_mut().remove(x, vm)
526530
}
527531

528-
#[pymethod]
529-
fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
532+
fn extend_from_iterable(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
530533
let mut array = self.borrow_value_mut();
531-
for elem in iter.iter(vm)? {
532-
array.push(elem?, vm)?;
534+
for obj in iter.iter(vm)? {
535+
array.push(obj?, vm)?;
533536
}
534537
Ok(())
535538
}
536539

540+
#[pymethod]
541+
fn extend(zelf: PyRef<Self>, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
542+
if zelf.is(&obj) {
543+
zelf.borrow_value_mut().imul(2);
544+
Ok(())
545+
} else if let Some(array) = obj.payload::<PyArray>() {
546+
zelf.borrow_value_mut().iadd(&*array.borrow_value(), vm)
547+
} else {
548+
let iter = PyIterable::try_from_object(vm, obj)?;
549+
zelf.extend_from_iterable(iter, vm)
550+
}
551+
}
552+
537553
#[pymethod]
538554
fn frombytes(&self, b: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> {
539555
let b = b.borrow_value();
@@ -670,9 +686,16 @@ impl PyArray {
670686
}
671687

672688
#[pymethod(name = "__add__")]
673-
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
689+
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
674690
if let Some(other) = other.payload::<PyArray>() {
675-
self.borrow_value().add(&*other.borrow_value(), vm)
691+
self.borrow_value()
692+
.add(&*other.borrow_value(), vm)
693+
.map(|array| {
694+
PyArray {
695+
array: PyRwLock::new(array),
696+
}
697+
.into_ref(vm)
698+
})
676699
} else {
677700
Err(vm.new_type_error(format!(
678701
"can only append array (not \"{}\") to array",
@@ -682,11 +705,13 @@ impl PyArray {
682705
}
683706

684707
#[pymethod(name = "__iadd__")]
685-
fn iadd(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
686-
if let Some(other) = other.payload::<PyArray>() {
687-
let other = other.borrow_value().clone();
688-
let result = zelf.borrow_value_mut().iadd(other, vm);
689-
result.map(|_| zelf.into_object())
708+
fn iadd(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
709+
if zelf.is(&other) {
710+
zelf.borrow_value_mut().imul(2);
711+
Ok(zelf)
712+
} else if let Some(other) = other.payload::<PyArray>() {
713+
let result = zelf.borrow_value_mut().iadd(&*other.borrow_value(), vm);
714+
result.map(|_| zelf)
690715
} else {
691716
Err(vm.new_type_error(format!(
692717
"can only extend array with array (not \"{}\")",
@@ -696,12 +721,15 @@ impl PyArray {
696721
}
697722

698723
#[pymethod(name = "__mul__")]
699-
fn mul(&self, counter: isize, vm: &VirtualMachine) -> PyObjectRef {
700-
self.borrow_value().mul(counter, vm)
724+
fn mul(&self, counter: isize, vm: &VirtualMachine) -> PyRef<Self> {
725+
PyArray {
726+
array: PyRwLock::new(self.borrow_value().mul(counter)),
727+
}
728+
.into_ref(vm)
701729
}
702730

703731
#[pymethod(name = "__rmul__")]
704-
fn rmul(&self, counter: isize, vm: &VirtualMachine) -> PyObjectRef {
732+
fn rmul(&self, counter: isize, vm: &VirtualMachine) -> PyRef<Self> {
705733
self.mul(counter, &vm)
706734
}
707735

0 commit comments

Comments
 (0)