Skip to content

Commit a26e5f2

Browse files
committed
Add dict iterator __length_hint__
1 parent 098393d commit a26e5f2

File tree

3 files changed

+34
-13
lines changed

3 files changed

+34
-13
lines changed

tests/snippets/dict.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,16 @@ def __eq__(self, other):
254254
assert not {}.__ne__({})
255255
assert {}.__ne__({'a':'b'})
256256
assert {}.__ne__(1) == NotImplemented
257+
258+
it = iter({0: 1, 2: 3, 4:5, 6:7})
259+
assert it.__length_hint__() == 4
260+
next(it)
261+
assert it.__length_hint__() == 3
262+
next(it)
263+
assert it.__length_hint__() == 2
264+
next(it)
265+
assert it.__length_hint__() == 1
266+
next(it)
267+
assert it.__length_hint__() == 0
268+
assert_raises(StopIteration, next, it)
269+
assert it.__length_hint__() == 0

vm/src/dictdatatype.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -204,26 +204,26 @@ impl<T: Clone> Dict<T> {
204204
}
205205

206206
pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(&PyObjectRef, &T)> {
207-
while *position < self.entries.len() {
208-
if let Some(DictEntry { key, value, .. }) = &self.entries[*position] {
209-
*position += 1;
210-
return Some((key, value));
211-
}
207+
self.entries[*position..].iter().find_map(|entry| {
212208
*position += 1;
213-
}
214-
None
209+
entry
210+
.as_ref()
211+
.map(|DictEntry { key, value, .. }| (key, value))
212+
})
213+
}
214+
215+
pub fn len_from_entry_index(&self, position: EntryIndex) -> usize {
216+
self.entries[position..].iter().flatten().count()
215217
}
216218

217219
pub fn has_changed_size(&self, position: &DictSize) -> bool {
218220
position.size != self.size || self.entries.len() != position.entries_size
219221
}
220222

221-
pub fn keys<'a>(&'a self) -> Box<dyn Iterator<Item = PyObjectRef> + 'a> {
222-
Box::new(
223-
self.entries
224-
.iter()
225-
.filter_map(|v| v.as_ref().map(|v| v.key.clone())),
226-
)
223+
pub fn keys<'a>(&'a self) -> impl Iterator<Item = PyObjectRef> + 'a {
224+
self.entries
225+
.iter()
226+
.filter_map(|v| v.as_ref().map(|v| v.key.clone()))
227227
}
228228

229229
/// Lookup the index for the given key.

vm/src/obj/objdict.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,14 @@ macro_rules! dict_iterator {
545545
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
546546
zelf
547547
}
548+
549+
#[pymethod(name = "__length_hint__")]
550+
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
551+
self.dict
552+
.entries
553+
.borrow()
554+
.len_from_entry_index(self.position.get())
555+
}
548556
}
549557

550558
impl PyValue for $iter_name {

0 commit comments

Comments
 (0)