Skip to content

Commit c9b479c

Browse files
committed
Guard for changes in dictionary size during iteration.
1 parent 5855baf commit c9b479c

File tree

3 files changed

+54
-14
lines changed

3 files changed

+54
-14
lines changed

tests/snippets/dict.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@
6969
data[3] = "changed"
7070
assert (3, "changed") == next(items)
7171

72+
# But we can't add or delete items during iteration.
73+
d = {}
74+
a = iter(d.items())
75+
d['a'] = 2
76+
b = iter(d.items())
77+
assert ('a', 2) == next(b)
78+
with assertRaises(RuntimeError):
79+
next(a)
80+
del d['a']
81+
with assertRaises(RuntimeError):
82+
next(b)
83+
7284
# View isn't itself an iterator.
7385
with assertRaises(TypeError):
7486
next(data.keys())

vm/src/dictdatatype.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ struct DictEntry<T> {
3333
value: T,
3434
}
3535

36+
#[derive(Debug)]
37+
pub struct DictSize {
38+
size: usize,
39+
entries_size: usize,
40+
}
41+
3642
impl<T: Clone> Dict<T> {
3743
pub fn new() -> Self {
3844
Dict {
@@ -120,17 +126,28 @@ impl<T: Clone> Dict<T> {
120126
self.len() == 0
121127
}
122128

123-
pub fn next_entry(&self, mut position: usize) -> Option<(usize, &PyObjectRef, &T)> {
124-
while position < self.entries.len() {
125-
if let Some(DictEntry { key, value, .. }) = &self.entries[position] {
126-
return Some((position + 1, key, value));
127-
} else {
128-
position += 1;
129+
pub fn size(&self) -> DictSize {
130+
DictSize {
131+
size: self.size,
132+
entries_size: self.entries.len(),
133+
}
134+
}
135+
136+
pub fn next_entry(&self, position: &mut usize) -> Option<(&PyObjectRef, &T)> {
137+
while *position < self.entries.len() {
138+
if let Some(DictEntry { key, value, .. }) = &self.entries[*position] {
139+
*position += 1;
140+
return Some((key, value));
129141
}
142+
*position += 1;
130143
}
131144
None
132145
}
133146

147+
pub fn has_changed_size(&self, position: &DictSize) -> bool {
148+
position.size != self.size || self.entries.len() != position.entries_size
149+
}
150+
134151
/// Lookup the index for the given key.
135152
fn lookup(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult<LookupResult> {
136153
let hash_value = calc_hash(vm, key)?;

vm/src/obj/objdict.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ impl PyDictRef {
260260
let key = key.into_pyobject(vm).unwrap();
261261
self.entries.borrow().contains(vm, &key).unwrap()
262262
}
263+
264+
pub fn size(&self) -> dictdatatype::DictSize {
265+
self.entries.borrow().size()
266+
}
263267
}
264268

265269
impl ItemProtocol for PyDictRef {
@@ -315,11 +319,8 @@ impl Iterator for DictIter {
315319
type Item = (PyObjectRef, PyObjectRef);
316320

317321
fn next(&mut self) -> Option<Self::Item> {
318-
match self.dict.entries.borrow().next_entry(self.position) {
319-
Some((new_position, key, value)) => {
320-
self.position = new_position;
321-
Some((key.clone(), value.clone()))
322-
}
322+
match self.dict.entries.borrow().next_entry(&mut self.position) {
323+
Some((key, value)) => Some((key.clone(), value.clone())),
323324
None => None,
324325
}
325326
}
@@ -360,6 +361,7 @@ macro_rules! dict_iterator {
360361
#[derive(Debug)]
361362
struct $iter_name {
362363
pub dict: PyDictRef,
364+
pub size: dictdatatype::DictSize,
363365
pub position: Cell<usize>,
364366
}
365367

@@ -368,15 +370,24 @@ macro_rules! dict_iterator {
368370
fn new(dict: PyDictRef) -> Self {
369371
$iter_name {
370372
position: Cell::new(0),
373+
size: dict.size(),
371374
dict,
372375
}
373376
}
374377

375378
#[pymethod(name = "__next__")]
376379
fn next(&self, vm: &VirtualMachine) -> PyResult {
377-
match self.dict.entries.borrow().next_entry(self.position.get()) {
378-
Some((new_position, key, value)) => {
379-
self.position.set(new_position);
380+
let mut position = self.position.get();
381+
let dict = self.dict.entries.borrow();
382+
if dict.has_changed_size(&self.size) {
383+
return Err(vm.new_exception(
384+
vm.ctx.exceptions.runtime_error.clone(),
385+
"dictionary changed size during iteration".to_string(),
386+
));
387+
}
388+
match dict.next_entry(&mut position) {
389+
Some((key, value)) => {
390+
self.position.set(position);
380391
Ok($result_fn(vm, key, value))
381392
}
382393
None => Err(objiter::new_stop_iteration(vm)),

0 commit comments

Comments
 (0)