Skip to content

Commit 7e10671

Browse files
committed
Add list.extend method
1 parent e6fd9fb commit 7e10671

File tree

4 files changed

+44
-23
lines changed

4 files changed

+44
-23
lines changed

tests/snippets/list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@
66
y = [2, *x]
77
assert y == [2, 1, 2, 3]
88

9+
y.extend(x)
10+
assert y == [2, 1, 2, 3, 1, 2, 3]
11+

vm/src/frame.rs

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use super::obj::objdict;
1313
use super::obj::objiter;
1414
use super::obj::objlist;
1515
use super::obj::objstr;
16-
use super::obj::objtuple;
1716
use super::obj::objtype;
1817
use super::pyobject::{
1918
AttributeProtocol, DictProtocol, IdProtocol, ParentProtocol, PyFuncArgs, PyObject,
@@ -482,7 +481,7 @@ impl Frame {
482481
vec![]
483482
};
484483
let args = self.pop_value();
485-
let args = self.extract_elements(vm, &args)?;
484+
let args = vm.extract_elements(&args)?;
486485
PyFuncArgs { args, kwargs }
487486
}
488487
};
@@ -595,7 +594,7 @@ impl Frame {
595594
}
596595
bytecode::Instruction::UnpackSequence { size } => {
597596
let value = self.pop_value();
598-
let elements = self.extract_elements(vm, &value)?;
597+
let elements = vm.extract_elements(&value)?;
599598
if elements.len() != *size {
600599
Err(vm.new_value_error("Wrong number of values to unpack".to_string()))
601600
} else {
@@ -607,7 +606,7 @@ impl Frame {
607606
}
608607
bytecode::Instruction::UnpackEx { before, after } => {
609608
let value = self.pop_value();
610-
let elements = self.extract_elements(vm, &value)?;
609+
let elements = vm.extract_elements(&value)?;
611610
let min_expected = *before + *after;
612611
if elements.len() < min_expected {
613612
Err(vm.new_value_error(format!(
@@ -642,7 +641,7 @@ impl Frame {
642641
}
643642
bytecode::Instruction::Unpack => {
644643
let value = self.pop_value();
645-
let elements = self.extract_elements(vm, &value)?;
644+
let elements = vm.extract_elements(&value)?;
646645
for element in elements.into_iter().rev() {
647646
self.push_value(element);
648647
}
@@ -651,23 +650,6 @@ impl Frame {
651650
}
652651
}
653652

654-
fn extract_elements(
655-
&mut self,
656-
vm: &mut VirtualMachine,
657-
value: &PyObjectRef,
658-
) -> Result<Vec<PyObjectRef>, PyObjectRef> {
659-
// Extract elements from item, if possible:
660-
let elements = if objtype::isinstance(value, &vm.ctx.tuple_type()) {
661-
objtuple::get_elements(value)
662-
} else if objtype::isinstance(value, &vm.ctx.list_type()) {
663-
objlist::get_elements(value)
664-
} else {
665-
let iter = objiter::get_iter(vm, value)?;
666-
objiter::get_all(vm, &iter)?
667-
};
668-
Ok(elements)
669-
}
670-
671653
fn get_elements(
672654
&mut self,
673655
vm: &mut VirtualMachine,
@@ -678,7 +660,7 @@ impl Frame {
678660
if unpack {
679661
let mut result: Vec<PyObjectRef> = vec![];
680662
for element in elements {
681-
let expanded = self.extract_elements(vm, &element)?;
663+
let expanded = vm.extract_elements(&element)?;
682664
for inner in expanded {
683665
result.push(inner);
684666
}

vm/src/obj/objlist.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,22 @@ fn list_clear(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
137137
}
138138
}
139139

140+
pub fn list_extend(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
141+
arg_check!(
142+
vm,
143+
args,
144+
required = [(list, Some(vm.ctx.list_type())), (x, None)]
145+
);
146+
let mut new_elements = vm.extract_elements(x)?;
147+
let mut list_obj = list.borrow_mut();
148+
if let PyObjectKind::List { ref mut elements } = list_obj.kind {
149+
elements.append(&mut new_elements);
150+
Ok(vm.get_none())
151+
} else {
152+
Err(vm.new_type_error("list.extend is called with no list".to_string()))
153+
}
154+
}
155+
140156
fn list_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
141157
trace!("list.len called with: {:?}", args);
142158
arg_check!(vm, args, required = [(list, Some(vm.ctx.list_type()))]);
@@ -198,5 +214,6 @@ pub fn init(context: &PyContext) {
198214
list_type.set_attr("__repr__", context.new_rustfunc(list_repr));
199215
list_type.set_attr("append", context.new_rustfunc(list_append));
200216
list_type.set_attr("clear", context.new_rustfunc(list_clear));
217+
list_type.set_attr("extend", context.new_rustfunc(list_extend));
201218
list_type.set_attr("reverse", context.new_rustfunc(list_reverse));
202219
}

vm/src/vm.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ use super::builtins;
1212
use super::bytecode;
1313
use super::frame::{copy_code, Frame};
1414
use super::obj::objgenerator;
15+
use super::obj::objiter;
16+
use super::obj::objlist;
1517
use super::obj::objobject;
18+
use super::obj::objtuple;
1619
use super::obj::objtype;
1720
use super::pyobject::{DictProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult};
1821
use super::stdlib;
@@ -351,6 +354,22 @@ impl VirtualMachine {
351354
Ok(())
352355
}
353356

357+
pub fn extract_elements(
358+
&mut self,
359+
value: &PyObjectRef,
360+
) -> Result<Vec<PyObjectRef>, PyObjectRef> {
361+
// Extract elements from item, if possible:
362+
let elements = if objtype::isinstance(value, &self.ctx.tuple_type()) {
363+
objtuple::get_elements(value)
364+
} else if objtype::isinstance(value, &self.ctx.list_type()) {
365+
objlist::get_elements(value)
366+
} else {
367+
let iter = objiter::get_iter(self, value)?;
368+
objiter::get_all(self, &iter)?
369+
};
370+
Ok(elements)
371+
}
372+
354373
pub fn get_attribute(&mut self, obj: PyObjectRef, attr_name: &str) -> PyResult {
355374
objtype::get_attribute(self, obj.clone(), attr_name)
356375
}

0 commit comments

Comments
 (0)