Skip to content

Commit 0f6a25c

Browse files
Merge pull request RustPython#791 from RustPython/dict_protocols
Dict protocols
2 parents a6858be + 584b707 commit 0f6a25c

File tree

13 files changed

+133
-83
lines changed

13 files changed

+133
-83
lines changed

tests/snippets/dict.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,27 @@ def dict_eq(d1, d2):
5757
assert x.get("here", "default") == "here"
5858
assert x.get("not here") == None
5959

60+
class LengthDict(dict):
61+
def __getitem__(self, k):
62+
return len(k)
63+
64+
x = LengthDict()
65+
assert type(x) == LengthDict
66+
assert x['word'] == 4
67+
assert x.get('word') is None
68+
69+
assert 5 == eval("a + word", LengthDict())
70+
71+
72+
class Squares(dict):
73+
def __missing__(self, k):
74+
v = k * k
75+
self[k] = v
76+
return v
77+
78+
x = Squares()
79+
assert x[-5] == 25
80+
6081
# An object that hashes to the same value always, and compares equal if any its values match.
6182
class Hashable(object):
6283
def __init__(self, *args):

vm/src/builtins.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::obj::objtype::{self, PyClassRef};
2020
use crate::frame::Scope;
2121
use crate::function::{Args, OptionalArg, PyFuncArgs};
2222
use crate::pyobject::{
23-
DictProtocol, IdProtocol, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject,
23+
IdProtocol, ItemProtocol, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject,
2424
TypeProtocol,
2525
};
2626
use crate::vm::VirtualMachine;
@@ -804,6 +804,6 @@ pub fn builtin_build_class_(vm: &VirtualMachine, mut args: PyFuncArgs) -> PyResu
804804
"__call__",
805805
vec![name_arg, bases, namespace.into_object()],
806806
)?;
807-
cells.set_item("__class__", class.clone(), vm);
807+
cells.set_item("__class__", class.clone(), vm)?;
808808
Ok(class)
809809
}

vm/src/dictdatatype.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,15 @@ impl Dict {
7676
}
7777

7878
/// Retrieve a key
79-
pub fn get(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult {
79+
pub fn get(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<Option<PyObjectRef>> {
8080
if let LookupResult::Existing(index) = self.lookup(vm, key)? {
8181
if let Some(entry) = &self.entries[index] {
82-
Ok(entry.value.clone())
82+
Ok(Some(entry.value.clone()))
8383
} else {
8484
panic!("Lookup returned invalid index into entries!");
8585
}
8686
} else {
87-
let key_repr = vm.to_pystr(key)?;
88-
Err(vm.new_key_error(format!("Key not found: {}", key_repr)))
87+
Ok(None)
8988
}
9089
}
9190

vm/src/frame.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ use crate::obj::objstr;
2121
use crate::obj::objtype;
2222
use crate::obj::objtype::PyClassRef;
2323
use crate::pyobject::{
24-
DictProtocol, IdProtocol, ItemProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
25-
TryFromObject, TypeProtocol,
24+
IdProtocol, ItemProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
25+
TypeProtocol,
2626
};
2727
use crate::vm::VirtualMachine;
2828
use itertools::Itertools;
@@ -133,12 +133,12 @@ pub trait NameProtocol {
133133
impl NameProtocol for Scope {
134134
fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option<PyObjectRef> {
135135
for dict in self.locals.iter() {
136-
if let Some(value) = dict.get_item(name, vm) {
136+
if let Some(value) = dict.get_item_option(name, vm).unwrap() {
137137
return Some(value);
138138
}
139139
}
140140

141-
if let Some(value) = self.globals.get_item(name, vm) {
141+
if let Some(value) = self.globals.get_item_option(name, vm).unwrap() {
142142
return Some(value);
143143
}
144144

@@ -147,19 +147,19 @@ impl NameProtocol for Scope {
147147

148148
fn load_cell(&self, vm: &VirtualMachine, name: &str) -> Option<PyObjectRef> {
149149
for dict in self.locals.iter().skip(1) {
150-
if let Some(value) = dict.get_item(name, vm) {
150+
if let Some(value) = dict.get_item_option(name, vm).unwrap() {
151151
return Some(value);
152152
}
153153
}
154154
None
155155
}
156156

157157
fn store_name(&self, vm: &VirtualMachine, key: &str, value: PyObjectRef) {
158-
self.get_locals().set_item(key, value, vm)
158+
self.get_locals().set_item(key, value, vm).unwrap();
159159
}
160160

161161
fn delete_name(&self, vm: &VirtualMachine, key: &str) {
162-
self.get_locals().del_item(key, vm)
162+
self.get_locals().del_item(key, vm).unwrap();
163163
}
164164
}
165165

@@ -394,12 +394,12 @@ impl Frame {
394394
obj.downcast().expect("Need a dictionary to build a map.");
395395
let dict_elements = dict.get_key_value_pairs();
396396
for (key, value) in dict_elements.iter() {
397-
map_obj.set_item(key.clone(), value.clone(), vm);
397+
map_obj.set_item(key.clone(), value.clone(), vm).unwrap();
398398
}
399399
}
400400
} else {
401401
for (key, value) in self.pop_multiple(2 * size).into_iter().tuples() {
402-
map_obj.set_item(key, value, vm)
402+
map_obj.set_item(key, value, vm).unwrap();
403403
}
404404
}
405405

vm/src/import.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::path::PathBuf;
88
use crate::compile;
99
use crate::frame::Scope;
1010
use crate::obj::{objsequence, objstr};
11-
use crate::pyobject::{DictProtocol, ItemProtocol, PyResult};
11+
use crate::pyobject::{ItemProtocol, PyResult};
1212
use crate::util;
1313
use crate::vm::VirtualMachine;
1414

@@ -39,7 +39,7 @@ fn import_uncached_module(vm: &VirtualMachine, current_path: PathBuf, module: &s
3939
// trace!("Code object: {:?}", code_obj);
4040

4141
let attrs = vm.ctx.new_dict();
42-
attrs.set_item("__name__", vm.new_str(module.to_string()), vm);
42+
attrs.set_item("__name__", vm.new_str(module.to_string()), vm)?;
4343
vm.run_code_obj(code_obj, Scope::new(None, attrs.clone()))?;
4444
Ok(vm.ctx.new_module(module, attrs))
4545
}

vm/src/obj/objdict.rs

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ use std::fmt;
33

44
use crate::function::{KwArgs, OptionalArg};
55
use crate::pyobject::{
6-
DictProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyContext, PyObjectRef, PyRef,
7-
PyResult, PyValue,
6+
IntoPyObject, ItemProtocol, PyAttributes, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
87
};
98
use crate::vm::{ReprGuard, VirtualMachine};
109

@@ -40,17 +39,18 @@ impl PyValue for PyDict {
4039
// Python dict methods:
4140
impl PyDictRef {
4241
fn new(
43-
_class: PyClassRef, // TODO Support subclasses of int.
42+
class: PyClassRef,
4443
dict_obj: OptionalArg<PyObjectRef>,
4544
kwargs: KwArgs,
4645
vm: &VirtualMachine,
4746
) -> PyResult<PyDictRef> {
48-
let dict = vm.ctx.new_dict();
47+
let mut dict = DictContentType::default();
48+
4949
if let OptionalArg::Present(dict_obj) = dict_obj {
5050
let dicted: PyResult<PyDictRef> = dict_obj.clone().downcast();
5151
if let Ok(dict_obj) = dicted {
5252
for (key, value) in dict_obj.get_key_value_pairs() {
53-
dict.set_item(key, value, vm);
53+
dict.insert(vm, &key, value)?;
5454
}
5555
} else {
5656
let iter = objiter::get_iter(vm, &dict_obj)?;
@@ -68,14 +68,17 @@ impl PyDictRef {
6868
if objiter::get_next_object(vm, &elem_iter)?.is_some() {
6969
return Err(err(vm));
7070
}
71-
dict.set_item(key, value, vm);
71+
dict.insert(vm, &key, value)?;
7272
}
7373
}
7474
}
7575
for (key, value) in kwargs.into_iter() {
76-
dict.set_item(vm.new_str(key), value, vm);
76+
dict.insert(vm, &vm.new_str(key), value)?;
77+
}
78+
PyDict {
79+
entries: RefCell::new(dict),
7780
}
78-
Ok(dict)
81+
.into_ref_with_type(vm, class)
7982
}
8083

8184
fn bool(self, _vm: &VirtualMachine) -> bool {
@@ -86,7 +89,7 @@ impl PyDictRef {
8689
self.entries.borrow().len()
8790
}
8891

89-
fn repr(self, vm: &VirtualMachine) -> PyResult {
92+
fn repr(self, vm: &VirtualMachine) -> PyResult<String> {
9093
let s = if let Some(_guard) = ReprGuard::enter(self.as_object()) {
9194
let mut str_parts = vec![];
9295
for (key, value) in self.get_key_value_pairs() {
@@ -99,14 +102,14 @@ impl PyDictRef {
99102
} else {
100103
"{...}".to_string()
101104
};
102-
Ok(vm.new_str(s))
105+
Ok(s)
103106
}
104107

105108
fn contains(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
106109
self.entries.borrow().contains(vm, &key)
107110
}
108111

109-
fn delitem(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
112+
fn inner_delitem(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
110113
self.entries.borrow_mut().delete(vm, &key)
111114
}
112115

@@ -170,25 +173,38 @@ impl PyDictRef {
170173
self.entries.borrow().get_items()
171174
}
172175

173-
fn setitem(self, key: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) {
174-
self.set_item(key, value, vm)
176+
fn inner_setitem(
177+
self,
178+
key: PyObjectRef,
179+
value: PyObjectRef,
180+
vm: &VirtualMachine,
181+
) -> PyResult<()> {
182+
self.entries.borrow_mut().insert(vm, &key, value)
175183
}
176184

177-
fn getitem(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
178-
self.entries.borrow().get(vm, &key)
185+
fn inner_getitem(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
186+
if let Some(value) = self.entries.borrow().get(vm, &key)? {
187+
return Ok(value);
188+
}
189+
190+
if let Ok(method) = vm.get_method(self.clone().into_object(), "__missing__") {
191+
return vm.invoke(method, vec![key]);
192+
}
193+
194+
Err(vm.new_key_error(format!("Key not found: {}", vm.to_pystr(&key)?)))
179195
}
180196

181197
fn get(
182198
self,
183199
key: PyObjectRef,
184200
default: OptionalArg<PyObjectRef>,
185201
vm: &VirtualMachine,
186-
) -> PyObjectRef {
187-
match self.into_object().get_item(key, vm) {
188-
Ok(value) => value,
189-
Err(_) => match default {
190-
OptionalArg::Present(value) => value,
191-
OptionalArg::Missing => vm.ctx.none(),
202+
) -> PyResult {
203+
match self.entries.borrow().get(vm, &key)? {
204+
Some(value) => Ok(value),
205+
None => match default {
206+
OptionalArg::Present(value) => Ok(value),
207+
OptionalArg::Missing => Ok(vm.ctx.none()),
192208
},
193209
}
194210
}
@@ -213,21 +229,22 @@ impl PyDictRef {
213229
}
214230
}
215231

216-
impl DictProtocol for PyDictRef {
217-
fn get_item<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) -> Option<PyObjectRef> {
218-
let key = key.into_pyobject(vm).unwrap();
219-
self.entries.borrow().get(vm, &key).ok()
232+
impl ItemProtocol for PyDictRef {
233+
fn get_item<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) -> PyResult {
234+
self.as_object().get_item(key, vm)
220235
}
221236

222-
// Item set/get:
223-
fn set_item<T: IntoPyObject>(&self, key: T, value: PyObjectRef, vm: &VirtualMachine) {
224-
let key = key.into_pyobject(vm).unwrap();
225-
self.entries.borrow_mut().insert(vm, &key, value).unwrap()
237+
fn set_item<T: IntoPyObject>(
238+
&self,
239+
key: T,
240+
value: PyObjectRef,
241+
vm: &VirtualMachine,
242+
) -> PyResult {
243+
self.as_object().set_item(key, value, vm)
226244
}
227245

228-
fn del_item<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) {
229-
let key = key.into_pyobject(vm).unwrap();
230-
self.entries.borrow_mut().delete(vm, &key).unwrap();
246+
fn del_item<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) -> PyResult {
247+
self.as_object().del_item(key, vm)
231248
}
232249
}
233250

@@ -236,12 +253,12 @@ pub fn init(context: &PyContext) {
236253
"__bool__" => context.new_rustfunc(PyDictRef::bool),
237254
"__len__" => context.new_rustfunc(PyDictRef::len),
238255
"__contains__" => context.new_rustfunc(PyDictRef::contains),
239-
"__delitem__" => context.new_rustfunc(PyDictRef::delitem),
240-
"__getitem__" => context.new_rustfunc(PyDictRef::getitem),
256+
"__delitem__" => context.new_rustfunc(PyDictRef::inner_delitem),
257+
"__getitem__" => context.new_rustfunc(PyDictRef::inner_getitem),
241258
"__iter__" => context.new_rustfunc(PyDictRef::iter),
242259
"__new__" => context.new_rustfunc(PyDictRef::new),
243260
"__repr__" => context.new_rustfunc(PyDictRef::repr),
244-
"__setitem__" => context.new_rustfunc(PyDictRef::setitem),
261+
"__setitem__" => context.new_rustfunc(PyDictRef::inner_setitem),
245262
"__hash__" => context.new_rustfunc(PyDictRef::hash),
246263
"clear" => context.new_rustfunc(PyDictRef::clear),
247264
"values" => context.new_rustfunc(PyDictRef::values),

vm/src/obj/objobject.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::function::PyFuncArgs;
66
use crate::obj::objproperty::PropertyBuilder;
77
use crate::obj::objtype::PyClassRef;
88
use crate::pyobject::{
9-
DictProtocol, IdProtocol, PyAttributes, PyContext, PyObject, PyObjectRef, PyResult, PyValue,
9+
IdProtocol, ItemProtocol, PyAttributes, PyContext, PyObject, PyObjectRef, PyResult, PyValue,
1010
TryFromObject, TypeProtocol,
1111
};
1212
use crate::vm::VirtualMachine;
@@ -77,7 +77,7 @@ fn object_setattr(
7777
}
7878

7979
if let Some(ref dict) = obj.clone().dict {
80-
dict.set_item(attr_name, value, vm);
80+
dict.set_item(attr_name, value, vm)?;
8181
Ok(())
8282
} else {
8383
Err(vm.new_attribute_error(format!(
@@ -98,7 +98,7 @@ fn object_delattr(obj: PyObjectRef, attr_name: PyStringRef, vm: &VirtualMachine)
9898
}
9999

100100
if let Some(ref dict) = obj.dict {
101-
dict.del_item(attr_name, vm);
101+
dict.del_item(attr_name, vm)?;
102102
Ok(())
103103
} else {
104104
Err(vm.new_attribute_error(format!(
@@ -208,7 +208,7 @@ fn object_getattribute(obj: PyObjectRef, name_str: PyStringRef, vm: &VirtualMach
208208
}
209209
}
210210

211-
if let Some(obj_attr) = object_getattr(&obj, &name, &vm) {
211+
if let Some(obj_attr) = object_getattr(&obj, &name, &vm)? {
212212
Ok(obj_attr)
213213
} else if let Some(attr) = objtype::class_get_attr(&cls, &name) {
214214
vm.call_get_descriptor(attr, obj)
@@ -219,11 +219,15 @@ fn object_getattribute(obj: PyObjectRef, name_str: PyStringRef, vm: &VirtualMach
219219
}
220220
}
221221

222-
fn object_getattr(obj: &PyObjectRef, attr_name: &str, vm: &VirtualMachine) -> Option<PyObjectRef> {
222+
fn object_getattr(
223+
obj: &PyObjectRef,
224+
attr_name: &str,
225+
vm: &VirtualMachine,
226+
) -> PyResult<Option<PyObjectRef>> {
223227
if let Some(ref dict) = obj.dict {
224-
dict.get_item(attr_name, vm)
228+
dict.get_item_option(attr_name, vm)
225229
} else {
226-
None
230+
Ok(None)
227231
}
228232
}
229233

vm/src/obj/objsuper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::obj::objfunction::PyMethod;
1212
use crate::obj::objstr;
1313
use crate::obj::objtype::{PyClass, PyClassRef};
1414
use crate::pyobject::{
15-
DictProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
15+
ItemProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
1616
};
1717
use crate::vm::VirtualMachine;
1818

@@ -124,7 +124,7 @@ fn super_new(
124124
} else {
125125
let frame = vm.current_frame().expect("no current frame for super()");
126126
if let Some(first_arg) = frame.code.arg_names.get(0) {
127-
match vm.get_locals().get_item(first_arg, vm) {
127+
match vm.get_locals().get_item_option(first_arg, vm)? {
128128
Some(obj) => obj.clone(),
129129
_ => {
130130
return Err(vm

0 commit comments

Comments
 (0)