Skip to content

Commit 2c3f4b6

Browse files
authored
Merge pull request RustPython#816 from RustPython/dict_changed_during_iteration
Guard for changes in dictionary size during iteration.
2 parents 5855baf + dfbaf3c commit 2c3f4b6

File tree

3 files changed

+55
-16
lines changed

3 files changed

+55
-16
lines changed

tests/snippets/dict.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@
6969
data[3] = "changed"
7070
assert (3, "changed") == next(items)
7171

72+
# But we can't add or delete items during iteration.
73+
d = {}
74+
a = iter(d.items())
75+
d['a'] = 2
76+
b = iter(d.items())
77+
assert ('a', 2) == next(b)
78+
with assertRaises(RuntimeError):
79+
next(a)
80+
del d['a']
81+
with assertRaises(RuntimeError):
82+
next(b)
83+
7284
# View isn't itself an iterator.
7385
with assertRaises(TypeError):
7486
next(data.keys())

vm/src/dictdatatype.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ struct DictEntry<T> {
3333
value: T,
3434
}
3535

36+
#[derive(Debug)]
37+
pub struct DictSize {
38+
size: usize,
39+
entries_size: usize,
40+
}
41+
3642
impl<T: Clone> Dict<T> {
3743
pub fn new() -> Self {
3844
Dict {
@@ -120,17 +126,28 @@ impl<T: Clone> Dict<T> {
120126
self.len() == 0
121127
}
122128

123-
pub fn next_entry(&self, mut position: usize) -> Option<(usize, &PyObjectRef, &T)> {
124-
while position < self.entries.len() {
125-
if let Some(DictEntry { key, value, .. }) = &self.entries[position] {
126-
return Some((position + 1, key, value));
127-
} else {
128-
position += 1;
129+
pub fn size(&self) -> DictSize {
130+
DictSize {
131+
size: self.size,
132+
entries_size: self.entries.len(),
133+
}
134+
}
135+
136+
pub fn next_entry(&self, position: &mut usize) -> Option<(&PyObjectRef, &T)> {
137+
while *position < self.entries.len() {
138+
if let Some(DictEntry { key, value, .. }) = &self.entries[*position] {
139+
*position += 1;
140+
return Some((key, value));
129141
}
142+
*position += 1;
130143
}
131144
None
132145
}
133146

147+
pub fn has_changed_size(&self, position: &DictSize) -> bool {
148+
position.size != self.size || self.entries.len() != position.entries_size
149+
}
150+
134151
/// Lookup the index for the given key.
135152
fn lookup(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<LookupResult> {
136153
let hash_value = calc_hash(vm, key)?;

vm/src/obj/objdict.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ pub type DictContentType = dictdatatype::Dict;
1919

2020
#[derive(Default)]
2121
pub struct PyDict {
22-
// TODO: should be private
23-
pub entries: RefCell<DictContentType>,
22+
entries: RefCell<DictContentType>,
2423
}
2524
pub type PyDictRef = PyRef<PyDict>;
2625

@@ -260,6 +259,10 @@ impl PyDictRef {
260259
let key = key.into_pyobject(vm).unwrap();
261260
self.entries.borrow().contains(vm, &key).unwrap()
262261
}
262+
263+
pub fn size(&self) -> dictdatatype::DictSize {
264+
self.entries.borrow().size()
265+
}
263266
}
264267

265268
impl ItemProtocol for PyDictRef {
@@ -315,11 +318,8 @@ impl Iterator for DictIter {
315318
type Item = (PyObjectRef, PyObjectRef);
316319

317320
fn next(&mut self) -> Option<Self::Item> {
318-
match self.dict.entries.borrow().next_entry(self.position) {
319-
Some((new_position, key, value)) => {
320-
self.position = new_position;
321-
Some((key.clone(), value.clone()))
322-
}
321+
match self.dict.entries.borrow().next_entry(&mut self.position) {
322+
Some((key, value)) => Some((key.clone(), value.clone())),
323323
None => None,
324324
}
325325
}
@@ -360,6 +360,7 @@ macro_rules! dict_iterator {
360360
#[derive(Debug)]
361361
struct $iter_name {
362362
pub dict: PyDictRef,
363+
pub size: dictdatatype::DictSize,
363364
pub position: Cell<usize>,
364365
}
365366

@@ -368,15 +369,24 @@ macro_rules! dict_iterator {
368369
fn new(dict: PyDictRef) -> Self {
369370
$iter_name {
370371
position: Cell::new(0),
372+
size: dict.size(),
371373
dict,
372374
}
373375
}
374376

375377
#[pymethod(name = "__next__")]
376378
fn next(&self, vm: &VirtualMachine) -> PyResult {
377-
match self.dict.entries.borrow().next_entry(self.position.get()) {
378-
Some((new_position, key, value)) => {
379-
self.position.set(new_position);
379+
let mut position = self.position.get();
380+
let dict = self.dict.entries.borrow();
381+
if dict.has_changed_size(&self.size) {
382+
return Err(vm.new_exception(
383+
vm.ctx.exceptions.runtime_error.clone(),
384+
"dictionary changed size during iteration".to_string(),
385+
));
386+
}
387+
match dict.next_entry(&mut position) {
388+
Some((key, value)) => {
389+
self.position.set(position);
380390
Ok($result_fn(vm, key, value))
381391
}
382392
None => Err(objiter::new_stop_iteration(vm)),

0 commit comments

Comments
 (0)