Skip to content

Commit

Permalink
Add a repr that extends the normal repr but also outputs the names. A…
Browse files Browse the repository at this point in the history
…lso fix slicing into a dimension that doesn't have names.

Fixes google-deepmind#174

PiperOrigin-RevId: 196511049
  • Loading branch information
tewalds committed May 14, 2018
1 parent 3a97a7e commit 47713aa
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
26 changes: 24 additions & 2 deletions pysc2/lib/named_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import print_function

import numbers
import re

import enum
import numpy as np
Expand Down Expand Up @@ -115,7 +116,7 @@ def __new__(cls, values, names, *args, **kwargs):

# Finally convert to a NamedNumpyArray.
obj = obj.view(cls)
obj._index_names = index_names
obj._index_names = index_names # [{name: index}, ...], dict per dimension.
return obj

def __array_finalize__(self, obj):
Expand Down Expand Up @@ -146,7 +147,7 @@ def __getitem__(self, indices):
if isinstance(obj, np.ndarray): # If this is a view, index the names too.
if isinstance(index, numbers.Integral):
obj._index_names = obj._index_names[1:]
elif isinstance(index, slice):
elif isinstance(index, slice) and self._index_names[0]:
# Rebuild the index of names.
names = sorted(obj._index_names[0].items(), key=lambda item: item[1])
sliced = {n: i for i, (n, _) in enumerate(names[index])}
Expand All @@ -169,8 +170,29 @@ def __getslice__(self, i, j): # deprecated, but still needed...
def __setslice__(self, i, j, seq): # deprecated, but still needed...
self[max(0, i):max(0, j):] = seq

def __repr__(self):
"""A repr, parsing the original and adding the names param."""
names = []
for dim_names in self._index_names:
if dim_names:
dim_names = [n for n, _ in sorted(dim_names.items(),
key=lambda item: item[1])]
if len(dim_names) > 11:
dim_names = dim_names[:5] + ["..."] + dim_names[-5:]
names.append(dim_names)
if len(names) == 1:
names = names[0]

# "NamedNumpyArray([1, 3, 6], dtype=int32)" ->
# ["NamedNumpyArray", "[1, 3, 6]", ", dtype=int32"]
matches = re.findall(r"^(\w+)\(([\d\., \n\[\]]*)(, \w+=.+)?\)$",
np.array_repr(self))[0]
return "%s(%s, %s%s)" % (
matches[0], matches[1], names, matches[2])


def _get_index(obj, index):
"""Turn a generalized index (int/slice/str) into a real index (int/slice)."""
if isinstance(index, (numbers.Integral, slice)):
return index
elif isinstance(index, str):
Expand Down
25 changes: 25 additions & 0 deletions pysc2/lib/named_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,31 @@ def test_named_array_multi_second(self):
with self.assertRaises(TypeError):
a.a # pylint: disable=pointless-statement

def test_string(self):
a = named_array.NamedNumpyArray([1, 3, 6], ["a", "b", "c"], dtype=np.int32)
self.assertEqual(str(a), "[1 3 6]")
self.assertEqual(repr(a), ("NamedNumpyArray([1, 3, 6], ['a', 'b', 'c'], "
"dtype=int32)"))

a = named_array.NamedNumpyArray([[1, 3], [6, 8]], [None, ["a", "b"]])
self.assertEqual(str(a), "[[1 3]\n [6 8]]")
self.assertEqual(repr(a), ("NamedNumpyArray([[1, 3],\n"
" [6, 8]], [None, ['a', 'b']])"))

a = named_array.NamedNumpyArray([[1, 3], [6, 8]], [["a", "b"], None])
self.assertEqual(str(a), "[[1 3]\n [6 8]]")
self.assertEqual(repr(a), ("NamedNumpyArray([[1, 3],\n"
" [6, 8]], [['a', 'b'], None])"))

a = named_array.NamedNumpyArray([list(range(50))] * 50,
[None, ["a%s" % i for i in range(50)]])
self.assertIn("49", str(a))
self.assertIn("49", repr(a))

a = named_array.NamedNumpyArray([list(range(50))] * 50,
[["a%s" % i for i in range(50)], None])
self.assertIn("49", str(a))
self.assertIn("49", repr(a))

if __name__ == "__main__":
absltest.main()

0 comments on commit 47713aa

Please sign in to comment.