Skip to content

Commit

Permalink
Please mypy and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince committed Feb 5, 2021
1 parent 4e1e52a commit 3440edd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
46 changes: 40 additions & 6 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,21 @@

import numpy as np

from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
from mesa.agent import Agent
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Sequence,
Tuple,
Union,
cast,
overload,
)
from .agent import Agent

Coordinate = Tuple[int, int]
GridContent = Union[Optional[Agent], Set[Agent]]
Expand Down Expand Up @@ -105,27 +118,47 @@ def default_val() -> None:
""" Default value for new cell elements. """
return None

@overload
def __getitem__(self, index: int) -> List[GridContent]:
...

@overload
def __getitem__(
self, index: Tuple[Union[int, slice], Union[int, slice]]
) -> Union[GridContent, List[GridContent]]:
...

@overload
def __getitem__(self, index: Sequence[Coordinate]) -> List[GridContent]:
...

def __getitem__(
self,
index: Union[int, Tuple[int, int], Tuple[slice, slice], Tuple[Coordinate]],
) -> Union[List[GridContent], GridContent]:
index: Union[
int, Sequence[Coordinate], Tuple[Union[int, slice], Union[int, slice]],
],
) -> Union[GridContent, List[GridContent]]:
"""Access contents from the grid."""

if isinstance(index, int):
# grid[x]
return self.grid[index]

if isinstance(index[0], tuple):
# grid[(x1, y1), (x2, y2)]
index = cast(Sequence[Coordinate], index)

cells = []
for pos in index:
x, y = self.torus_adj(pos)
cells.append(self.grid[x][y])
x1, y1 = self.torus_adj(pos)
cells.append(self.grid[x1][y1])
return cells

x, y = index

if isinstance(x, int) and isinstance(y, int):
# grid[x, y]
index = cast(Coordinate, index)
x, y = self.torus_adj(index)
return self.grid[x][y]

Expand All @@ -140,6 +173,7 @@ def __getitem__(
y = slice(y, y + 1)

# grid[:, :]
x, y = (cast(slice, x), cast(slice, y))
cells = []
for rows in self.grid[x]:
for cell in rows[y]:
Expand Down
28 changes: 27 additions & 1 deletion tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,5 +431,31 @@ def test_neighbors(self):
assert len(neighborhood) == 6


if __name__ == "__main__":
class TestIndexing:
# Create a grid where the content of each coordinate is a tuple of its coordinates
grid = Grid(3, 5, True)
for _, x, y in grid.coord_iter():
grid.grid[x][y] = (x, y)

def test_int(self):
assert self.grid[0][0] == (0, 0)

def test_tuple(self):
assert self.grid[1, 1] == (1, 1)

def test_list(self):
assert self.grid[(0, 0), (1, 1)] == [(0, 0), (1, 1)]
assert self.grid[(0, 0), (5, 3)] == [(0, 0), (2, 3)]

def test_torus(self):
assert self.grid[3, 5] == (0, 0)

def test_slice(self):
assert self.grid[:, 0] == [(0, 0), (1, 0), (2, 0)]
assert self.grid[::-1, 0] == [(2, 0), (1, 0), (0, 0)]
assert self.grid[1, :] == [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)]
assert self.grid[:, :] == [(x, y) for x in range(3) for y in range(5)]


if __name__ == '__main__':
unittest.main()

0 comments on commit 3440edd

Please sign in to comment.