Skip to content
This repository was archived by the owner on Apr 2, 2020. It is now read-only.

Commit 039b5bb

Browse files
committed
Implement dict __setitem__
1 parent 9869ef7 commit 039b5bb

File tree

7 files changed

+75
-56
lines changed

7 files changed

+75
-56
lines changed

tests/snippets/comprehensions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
# TODO: how to check set equality?
1616
# assert v == {2, 6, 4}
1717

18-
# TODO:
19-
#u = {str(b): b-2 for b in x}
18+
u = {str(b): b-2 for b in x}
19+
assert u['3'] == 1
20+
assert u['1'] == -1
2021

2122
y = [a+2 for a in x if a % 2]
2223
print(y)

vm/src/builtins.rs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,8 @@ fn builtin_bin(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
9595
Ok(vm.new_str(s))
9696
}
9797

98-
// builtin_bool
9998
// builtin_breakpoint
10099
// builtin_bytearray
101-
// builtin_bytes
102100
// builtin_callable
103101

104102
fn builtin_chr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -219,7 +217,6 @@ fn builtin_exec(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
219217
}
220218

221219
// builtin_filter
222-
// builtin_float
223220
// builtin_format
224221
// builtin_frozenset
225222

@@ -275,7 +272,6 @@ fn builtin_id(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
275272
}
276273

277274
// builtin_input
278-
// builtin_int
279275

280276
fn builtin_isinstance(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
281277
arg_check!(vm, args, required = [(obj, None), (typ, None)]);
@@ -325,8 +321,6 @@ fn builtin_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
325321
}
326322
}
327323

328-
// builtin_list
329-
330324
fn builtin_locals(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
331325
arg_check!(vm, args);
332326
Ok(vm.get_locals())
@@ -412,7 +406,6 @@ fn builtin_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
412406
}
413407
}
414408

415-
// builtin_object
416409
// builtin_oct
417410
// builtin_open
418411

@@ -498,14 +491,12 @@ fn builtin_range(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
498491
}
499492
}
500493

501-
// builtin_repr
502494
fn builtin_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
503495
arg_check!(vm, args, required = [(obj, None)]);
504496
vm.to_repr(obj)
505497
}
506498
// builtin_reversed
507499
// builtin_round
508-
// builtin_set
509500

510501
fn builtin_setattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
511502
arg_check!(
@@ -570,12 +561,10 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
570561
dict.insert(String::from("list"), ctx.list_type());
571562
dict.insert(String::from("locals"), ctx.new_rustfunc(builtin_locals));
572563
dict.insert(String::from("map"), ctx.new_rustfunc(builtin_map));
573-
574564
dict.insert(String::from("max"), ctx.new_rustfunc(builtin_max));
575565
dict.insert(String::from("min"), ctx.new_rustfunc(builtin_min));
576-
566+
dict.insert(String::from("object"), ctx.object());
577567
dict.insert(String::from("ord"), ctx.new_rustfunc(builtin_ord));
578-
579568
dict.insert(String::from("next"), ctx.new_rustfunc(builtin_next));
580569
dict.insert(String::from("pow"), ctx.new_rustfunc(builtin_pow));
581570
dict.insert(String::from("print"), ctx.new_rustfunc(builtin_print));
@@ -586,7 +575,6 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
586575
dict.insert(String::from("str"), ctx.str_type());
587576
dict.insert(String::from("tuple"), ctx.tuple_type());
588577
dict.insert(String::from("type"), ctx.type_type());
589-
dict.insert(String::from("object"), ctx.object());
590578

591579
// Exceptions:
592580
dict.insert(

vm/src/frame.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,14 +247,14 @@ impl Frame {
247247
Ok(None)
248248
}
249249
bytecode::Instruction::BuildMap { size, unpack } => {
250-
let mut elements = HashMap::new();
250+
let mut elements: HashMap<String, PyObjectRef> = HashMap::new();
251251
for _x in 0..*size {
252252
let obj = self.pop_value();
253253
if *unpack {
254254
// Take all key-value pairs from the dict:
255255
let dict_elements = objdict::get_elements(&obj);
256-
for (key, obj) in dict_elements {
257-
elements.insert(key, obj);
256+
for (key, obj) in dict_elements.iter() {
257+
elements.insert(key.clone(), obj.clone());
258258
}
259259
} else {
260260
// XXX: Currently, we only support String keys, so we have to unwrap the
@@ -471,7 +471,8 @@ impl Frame {
471471
bytecode::CallType::Ex(has_kwargs) => {
472472
let kwargs = if *has_kwargs {
473473
let kw_dict = self.pop_value();
474-
objdict::get_elements(&kw_dict).into_iter().collect()
474+
let dict_elements = objdict::get_elements(&kw_dict).clone();
475+
dict_elements.into_iter().collect()
475476
} else {
476477
vec![]
477478
};

vm/src/obj/objdict.rs

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,9 @@ use super::super::vm::VirtualMachine;
66
use super::objstr;
77
use super::objtype;
88
use num_bigint::ToBigInt;
9+
use std::cell::{Ref, RefMut};
910
use std::collections::HashMap;
10-
11-
pub fn _set_item(
12-
vm: &mut VirtualMachine,
13-
_d: PyObjectRef,
14-
_idx: PyObjectRef,
15-
_obj: PyObjectRef,
16-
) -> PyResult {
17-
// TODO: Implement objdict::set_item
18-
Ok(vm.get_none())
19-
}
11+
use std::ops::{Deref, DerefMut};
2012

2113
pub fn new(dict_type: PyObjectRef) -> PyObjectRef {
2214
PyObject::new(
@@ -27,12 +19,28 @@ pub fn new(dict_type: PyObjectRef) -> PyObjectRef {
2719
)
2820
}
2921

30-
pub fn get_elements(obj: &PyObjectRef) -> HashMap<String, PyObjectRef> {
31-
if let PyObjectKind::Dict { elements } = &obj.borrow().kind {
32-
elements.clone()
33-
} else {
34-
panic!("Cannot extract dict elements");
35-
}
22+
pub fn get_elements<'a>(
23+
obj: &'a PyObjectRef,
24+
) -> impl Deref<Target = HashMap<String, PyObjectRef>> + 'a {
25+
Ref::map(obj.borrow(), |py_obj| {
26+
if let PyObjectKind::Dict { ref elements } = py_obj.kind {
27+
elements
28+
} else {
29+
panic!("Cannot extract dict elements");
30+
}
31+
})
32+
}
33+
34+
fn get_mut_elements<'a>(
35+
obj: &'a PyObjectRef,
36+
) -> impl DerefMut<Target = HashMap<String, PyObjectRef>> + 'a {
37+
RefMut::map(obj.borrow_mut(), |py_obj| {
38+
if let PyObjectKind::Dict { ref mut elements } = py_obj.kind {
39+
elements
40+
} else {
41+
panic!("Cannot extract dict elements");
42+
}
43+
})
3644
}
3745

3846
fn dict_new(_vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -50,13 +58,13 @@ fn dict_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
5058

5159
let elements = get_elements(o);
5260
let mut str_parts = vec![];
53-
for elem in elements {
61+
for elem in elements.iter() {
5462
let s = vm.to_repr(&elem.1)?;
5563
let value_str = objstr::get_value(&s);
5664
str_parts.push(format!("{}: {}", elem.0, value_str));
5765
}
5866

59-
let s = format!("{{ {} }}", str_parts.join(", "));
67+
let s = format!("{{{}}}", str_parts.join(", "));
6068
Ok(vm.new_str(s))
6169
}
6270

@@ -80,7 +88,7 @@ pub fn dict_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
8088
Ok(vm.new_bool(false))
8189
}
8290

83-
pub fn dict_delitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
91+
fn dict_delitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
8492
arg_check!(
8593
vm,
8694
args,
@@ -94,18 +102,34 @@ pub fn dict_delitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
94102
let needle = objstr::get_value(&needle);
95103

96104
// Delete the item:
97-
let mut dict_obj = dict.borrow_mut();
98-
if let PyObjectKind::Dict { ref mut elements } = dict_obj.kind {
99-
match elements.remove(&needle) {
100-
Some(_) => Ok(vm.get_none()),
101-
None => Err(vm.new_value_error(format!("Key not found: {}", needle))),
102-
}
103-
} else {
104-
panic!("Cannot extract dict elements");
105+
let mut elements = get_mut_elements(dict);
106+
match elements.remove(&needle) {
107+
Some(_) => Ok(vm.get_none()),
108+
None => Err(vm.new_value_error(format!("Key not found: {}", needle))),
105109
}
106110
}
107111

108-
pub fn dict_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
112+
fn dict_setitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
113+
arg_check!(
114+
vm,
115+
args,
116+
required = [
117+
(dict, Some(vm.ctx.dict_type())),
118+
(needle, Some(vm.ctx.str_type())),
119+
(value, None)
120+
]
121+
);
122+
123+
// What we are looking for:
124+
let needle = objstr::get_value(&needle);
125+
126+
// Delete the item:
127+
let mut elements = get_mut_elements(dict);
128+
elements.insert(needle, value.clone());
129+
Ok(vm.get_none())
130+
}
131+
132+
fn dict_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
109133
arg_check!(
110134
vm,
111135
args,
@@ -143,4 +167,5 @@ pub fn init(context: &PyContext) {
143167
dict_type.set_attr("__getitem__", context.new_rustfunc(dict_getitem));
144168
dict_type.set_attr("__new__", context.new_rustfunc(dict_new));
145169
dict_type.set_attr("__repr__", context.new_rustfunc(dict_repr));
170+
dict_type.set_attr("__setitem__", context.new_rustfunc(dict_setitem));
146171
}

vm/src/obj/objset.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,17 @@ fn set_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
9898
arg_check!(vm, args, required = [(o, Some(vm.ctx.set_type()))]);
9999

100100
let elements = get_elements(o);
101-
let mut str_parts = vec![];
102-
for elem in elements.values() {
103-
let part = vm.to_repr(elem)?;
104-
str_parts.push(objstr::get_value(&part));
105-
}
101+
let s = if elements.len() == 0 {
102+
"set()".to_string()
103+
} else {
104+
let mut str_parts = vec![];
105+
for elem in elements.values() {
106+
let part = vm.to_repr(elem)?;
107+
str_parts.push(objstr::get_value(&part));
108+
}
106109

107-
let s = format!("{{ {} }}", str_parts.join(", "));
110+
format!("{{{}}}", str_parts.join(", "))
111+
};
108112
Ok(vm.new_str(s))
109113
}
110114

vm/src/obj/objtype.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ pub fn get_attributes(obj: &PyObjectRef) -> HashMap<String, PyObjectRef> {
209209
} = &bc.borrow().kind
210210
{
211211
let elements = objdict::get_elements(dict);
212-
for (name, value) in elements {
212+
for (name, value) in elements.iter() {
213213
attributes.insert(name.to_string(), value.clone());
214214
}
215215
}
@@ -218,7 +218,7 @@ pub fn get_attributes(obj: &PyObjectRef) -> HashMap<String, PyObjectRef> {
218218
// Get instance attributes:
219219
if let PyObjectKind::Instance { dict } = &obj.borrow().kind {
220220
let elements = objdict::get_elements(dict);
221-
for (name, value) in elements {
221+
for (name, value) in elements.iter() {
222222
attributes.insert(name.to_string(), value.clone());
223223
}
224224
}

vm/src/stdlib/json.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl<'s> serde::Serialize for PyObjectSerializer<'s> {
6262
} else if objtype::isinstance(self.pyobject, &self.ctx.dict_type()) {
6363
let elements = objdict::get_elements(self.pyobject);
6464
let mut map = serializer.serialize_map(Some(elements.len()))?;
65-
for (key, e) in elements {
65+
for (key, e) in elements.iter() {
6666
map.serialize_entry(&key, &self.clone_with_object(&e))?;
6767
}
6868
map.end()

0 commit comments

Comments
 (0)