Skip to content

Commit

Permalink
[mojo-stdlib] Dict iteration over keys, value, entries (#30936)
Browse files Browse the repository at this point in the history
- `Dict.__iter__` iterates over dict keys, immutably
- `Dict.keys()` iterates over dict keys, immutably
- Note that this is different from Python. `Dict.keys()` in Python is a
`key view` collection, which is immutable and implements the semantics
of a set. What we're doing here isn't incompatible with that long-term.
- `Dict.items()` iterates over dict entries, immutably
  - If there's a way to just allow mutability on `value` I don't see it
- Unlike Python these can't be unpacked to `k, v`. We could implement
that with [Internal Link]
- `Dict.values()` iterates over dict values, **maybe mutably**.
- This is okay because mutating the values can't impact the integrity of
the dictionary internals.

Stack:
- [Internal Link]
- [Internal Link]
- [Internal Link]
- -> [Internal Link]

modular-orig-commit: 7c9ee251a9b407492f61facf9f3f79c750d19790
  • Loading branch information
bethebunny authored Feb 6, 2024
1 parent e6356e0 commit be84840
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 24 deletions.
213 changes: 213 additions & 0 deletions stdlib/src/collections/dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,141 @@ trait KeyElement(CollectionElement, Hashable, EqualityComparable):
pass


@value
struct _DictEntryIter[
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
]:
"""Iterator over immutable DictEntry references.
Parameters:
K: The key type of the elements in the dictionary.
V: The value type of the elements in the dictionary.
dict_mutability: Whether the reference to the dictionary is mutable.
dict_lifetime: The lifetime of the DynamicVector
"""

alias imm_dict_lifetime = __mlir_attr[
`#lit.lifetime.mutcast<`, dict_lifetime, `> : !lit.lifetime<1>`
]
alias ref_type = Reference[
DictEntry[K, V], __mlir_attr.`0: i1`, Self.imm_dict_lifetime
]

var index: Int
var seen: Int
var src: Reference[Dict[K, V], dict_mutability, dict_lifetime]

fn __iter__(self) -> Self:
return self

fn __next__(inout self) -> Self.ref_type:
while True:
debug_assert(self.index < self.src[]._reserved, "dict iter bounds")
if self.src[]._entries[self.index]:
let opt_entry_ref = self.src[]._entries.__get_ref[
__mlir_attr.`0: i1`,
Self.imm_dict_lifetime,
](self.index)
self.index += 1
self.seen += 1
# Super unsafe, but otherwise we have to do a bunch of super
# unsafe reference lifetime casting.
return opt_entry_ref.bitcast_element[DictEntry[K, V]]()
self.index += 1

fn __len__(self) -> Int:
return len(self.src[]) - self.seen


@value
struct _DictKeyIter[
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
]:
"""Iterator over immutable Dict key references.
Parameters:
K: The key type of the elements in the dictionary.
V: The value type of the elements in the dictionary.
dict_mutability: Whether the reference to the vector is mutable.
dict_lifetime: The lifetime of the DynamicVector
"""

alias imm_dict_lifetime = __mlir_attr[
`#lit.lifetime.mutcast<`, dict_lifetime, `> : !lit.lifetime<1>`
]
alias ref_type = Reference[K, __mlir_attr.`0: i1`, Self.imm_dict_lifetime]

var iter: _DictEntryIter[K, V, dict_mutability, dict_lifetime]

fn __iter__(self) -> Self:
return self

fn __next__(inout self) -> Self.ref_type:
let entry_ref = self.iter.__next__()
let mlir_ptr = __mlir_op.`lit.ref.to_pointer`(
Reference(entry_ref[].key).value
)
let key_ptr = AnyPointer[K] {
value: __mlir_op.`pop.pointer.bitcast`[
_type = AnyPointer[K].pointer_type
](mlir_ptr)
}
return __mlir_op.`lit.ref.from_pointer`[
_type = Self.ref_type.mlir_ref_type
](key_ptr.value)

fn __len__(self) -> Int:
return self.iter.__len__()


@value
struct _DictValueIter[
K: KeyElement,
V: CollectionElement,
dict_mutability: __mlir_type.`i1`,
dict_lifetime: AnyLifetime[dict_mutability].type,
]:
"""Iterator over Dict value references. These are mutable if the dict
is mutable.
Parameters:
K: The key type of the elements in the dictionary.
V: The value type of the elements in the dictionary.
dict_mutability: Whether the reference to the vector is mutable.
dict_lifetime: The lifetime of the DynamicVector
"""

alias ref_type = Reference[V, dict_mutability, dict_lifetime]

var iter: _DictEntryIter[K, V, dict_mutability, dict_lifetime]

fn __iter__(self) -> Self:
return self

fn __next__(inout self) -> Self.ref_type:
let entry_ref = self.iter.__next__()
let mlir_ptr = __mlir_op.`lit.ref.to_pointer`(
Reference(entry_ref[].value).value
)
let value_ptr = AnyPointer[V] {
value: __mlir_op.`pop.pointer.bitcast`[
_type = AnyPointer[V].pointer_type
](mlir_ptr)
}
return __mlir_op.`lit.ref.from_pointer`[
_type = Self.ref_type.mlir_ref_type
](value_ptr.value)

fn __len__(self) -> Int:
return self.iter.__len__()


@value
struct DictEntry[K: KeyElement, V: CollectionElement](CollectionElement):
"""Store a key-value pair entry inside a dictionary.
Expand Down Expand Up @@ -414,6 +549,84 @@ struct Dict[K: KeyElement, V: CollectionElement](Sized):
return default.value()
raise "KeyError"

fn __iter__[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life].mlir_ref_type,
) -> _DictKeyIter[K, V, mutability, self_life]:
"""Iterate over the dict's keys as immutable references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary keys.
"""
return _DictKeyIter(
_DictEntryIter[K, V, mutability, self_life](0, 0, Reference(self))
)

fn keys[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life].mlir_ref_type,
) -> _DictKeyIter[K, V, mutability, self_life]:
"""Iterate over the dict's keys as immutable references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary keys.
"""
return Self.__iter__(self)

fn values[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life].mlir_ref_type,
) -> _DictValueIter[K, V, mutability, self_life]:
"""Iterate over the dict's values as references.
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of references to the dictionary values.
"""
return _DictValueIter(
_DictEntryIter[K, V, mutability, self_life](0, 0, Reference(self))
)

fn items[
mutability: __mlir_type.`i1`, self_life: AnyLifetime[mutability].type
](
self: Reference[Self, mutability, self_life].mlir_ref_type,
) -> _DictEntryIter[K, V, mutability, self_life]:
"""Iterate over the dict's entries as immutable references.
These can't yet be unpacked like Python dict items, but you can
access the key and value as attributes ie.
```mojo
for e in dict.items():
print(e[].key, e[].value)
```
Parameters:
mutability: Whether the dict is mutable.
self_life: The dict's lifetime.
Returns:
An iterator of immutable references to the dictionary entries.
"""
return _DictEntryIter[K, V, mutability, self_life](
0, 0, Reference(self)
)

@staticmethod
fn _new_entries(reserved: Int) -> DynamicVector[Optional[DictEntry[K, V]]]:
var entries = DynamicVector[Optional[DictEntry[K, V]]](reserved)
Expand Down
95 changes: 71 additions & 24 deletions stdlib/test/collections/test_dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,6 @@ from collections import Optional
from testing import *


@value
struct assert_raises:
var message: Optional[StringLiteral]

fn __enter__(self) -> Self:
return self

fn __exit__(self) raises:
var message = String("Test didn't raise!")
if self.message:
message += " Expected: " + str(self.message.value())
assert_true(False, message)

fn __exit__(self, error: Error) raises -> Bool:
let message = str(error)
if self.message:
let expected = String(self.message.value())
return expected == message
else:
return True


def test_dict_construction():
_ = Dict[Int, Int]()
_ = Dict[String, Int]()
Expand Down Expand Up @@ -84,12 +62,76 @@ def test_pop_default():
def test_key_error():
var dict = Dict[String, Int]()

with assert_raises("KeyError"):
with assert_raises(contains="KeyError"):
_ = dict["a"]
with assert_raises("KeyError"):
with assert_raises(contains="KeyError"):
_ = dict.pop("a")


def test_iter():
var dict = Dict[String, Int]()
dict["a"] = 1
dict["b"] = 2

var keys = String("")
for key in dict:
keys += key[]

assert_equal(keys, "ab")


def test_iter_keys():
var dict = Dict[String, Int]()
dict["a"] = 1
dict["b"] = 2

var keys = String("")
for key in dict.keys():
keys += key[]

assert_equal(keys, "ab")


def test_iter_values():
var dict = Dict[String, Int]()
dict["a"] = 1
dict["b"] = 2

var sum = 0
for value in dict.values():
sum += value[]

assert_equal(sum, 3)


def test_iter_values_mut():
var dict = Dict[String, Int]()
dict["a"] = 1
dict["b"] = 2

for value in dict.values():
value[] += 1

assert_equal(2, dict["a"])
assert_equal(3, dict["b"])
assert_equal(2, len(dict))


def test_iter_items():
var dict = Dict[String, Int]()
dict["a"] = 1
dict["b"] = 2

var keys = String("")
var sum = 0
for entry in dict.items():
keys += entry[].key
sum += entry[].value

assert_equal(keys, "ab")
assert_equal(sum, 3)


fn test[name: String, test_fn: fn () raises -> object]() raises:
var name_val = name # FIXME(#26974): Can't pass 'name' directly.
print_no_newline("Test", name_val, "...")
Expand All @@ -109,3 +151,8 @@ def main():
test["test_compact", test_compact]()
test["test_pop_default", test_pop_default]()
test["test_key_error", test_key_error]()
test["test_iter", test_iter]()
test["test_iter_keys", test_iter_keys]()
test["test_iter_values", test_iter_values]()
test["test_iter_values_mut", test_iter_values_mut]()
test["test_iter_items", test_iter_items]()

0 comments on commit be84840

Please sign in to comment.