Skip to content

Commit

Permalink
Fix indexing NamedNumpyArray with newaxis.
Browse files Browse the repository at this point in the history
Fixes google-deepmind#273

PiperOrigin-RevId: 268685149
  • Loading branch information
tewalds committed Sep 13, 2019
1 parent 5ca04db commit e994d8b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
18 changes: 12 additions & 6 deletions pysc2/lib/named_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,21 @@ def __getitem__(self, indices):
dim = 0
for i, index in enumerate(indices):
if isinstance(index, numbers.Integral):
pass # Drop this dimension's names.
dim += 1 # Drop this dimension's names.
elif index is Ellipsis:
# Copy all the dimensions' names through.
end = len(self.shape) - len(indices) + i
for j in range(dim, end + 1):
end = len(self.shape) - len(indices) + i + 1
for j in range(dim, end):
new_names.append(self._index_names[j])
dim = end
elif index is np.newaxis: # Add an unnamed dimension.
new_names.append(None)
# Don't modify dim, as we're still working on the same one.
elif (self._index_names[dim] is None or
(isinstance(index, slice) and index == _NULL_SLICE)):
# Keep unnamed dimensions or ones where the slice is a no-op.
new_names.append(self._index_names[dim])
dim += 1
elif isinstance(index, (slice, list, np.ndarray)):
if isinstance(index, np.ndarray) and len(index.shape) > 1:
raise TypeError("What does it mean to index into a named array by "
Expand All @@ -191,9 +195,9 @@ def __getitem__(self, indices):
# Names aren't unique, so drop the names for this dimension.
indexed = None
new_names.append(indexed)
dim += 1
else:
raise TypeError("Unknown index: %s; %s" % (type(index), index))
dim += 1
obj._index_names = new_names + self._index_names[dim:]
if len(obj._index_names) != len(obj.shape):
raise IndexError("Names don't match object shape: %s != %s" % (
Expand Down Expand Up @@ -250,10 +254,12 @@ def _indices(self, indices):
for i, index in enumerate(indices):
if index is Ellipsis:
out.append(index)
dim = len(self.shape) - len(indices) + i
dim = len(self.shape) - len(indices) + i + 1
elif index is np.newaxis:
out.append(None)
else:
out.append(self._get_index(dim, index))
dim += 1
dim += 1
return tuple(out)
else:
return self._get_index(0, indices)
Expand Down
32 changes: 32 additions & 0 deletions pysc2/lib/named_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,21 @@ def test_single_dimension(self, names):
with self.assertRaises(KeyError):
a["d"] # pylint: disable=pointless-statement

# New axis = None
self.assertArrayEqual(a, [1, 3, 6])
self.assertArrayEqual(a[np.newaxis], [[1, 3, 6]])
self.assertArrayEqual(a[None], [[1, 3, 6]])
self.assertArrayEqual(a[None, :], [[1, 3, 6]])
self.assertArrayEqual(a[:, None], [[1], [3], [6]])
self.assertArrayEqual(a[None, :, None], [[[1], [3], [6]]])
self.assertArrayEqual(a[None, a % 3 == 0, None], [[[3], [6]]])
self.assertArrayEqual(a[None][None], [[[1, 3, 6]]])
self.assertArrayEqual(a[None][0], [1, 3, 6])
self.assertEqual(a[None, 0], 1)
self.assertEqual(a[None, "a"], 1)
self.assertEqual(a[None][0].a, 1)
self.assertEqual(a[None][0, "b"], 3)

# range slicing
self.assertArrayEqual(a[0:2], [1, 3])
self.assertArrayEqual(a[1:3], [3, 6])
Expand Down Expand Up @@ -194,6 +209,22 @@ def test_named_array_multi_first(self):
with self.assertRaises(TypeError):
a[0].a # pylint: disable=pointless-statement

# New axis = None
self.assertArrayEqual(a, [[1, 3], [6, 8]])
self.assertArrayEqual(a[np.newaxis], [[[1, 3], [6, 8]]])
self.assertArrayEqual(a[None], [[[1, 3], [6, 8]]])
self.assertArrayEqual(a[None, :], [[[1, 3], [6, 8]]])
self.assertArrayEqual(a[None, "a"], [[1, 3]])
self.assertArrayEqual(a[:, None], [[[1, 3]], [[6, 8]]])
self.assertArrayEqual(a[None, :, None], [[[[1, 3]], [[6, 8]]]])
self.assertArrayEqual(a[None, 0, None], [[[1, 3]]])
self.assertArrayEqual(a[None, "a", None], [[[1, 3]]])
self.assertArrayEqual(a[None][None], [[[[1, 3], [6, 8]]]])
self.assertArrayEqual(a[None][0], [[1, 3], [6, 8]])
self.assertArrayEqual(a[None][0].a, [1, 3])
self.assertEqual(a[None][0].a[0], 1)
self.assertEqual(a[None][0, "b", 1], 8)

def test_named_array_multi_second(self):
a = named_array.NamedNumpyArray([[1, 3], [6, 8]], [None, ["a", "b"]])
self.assertArrayEqual(a[0], [1, 3])
Expand All @@ -206,6 +237,7 @@ def test_named_array_multi_second(self):
self.assertArrayEqual(a[a % 3 == 0], [3, 6])
with self.assertRaises(TypeError):
a.a # pylint: disable=pointless-statement
self.assertArrayEqual(a[None, :, "a"], [[1, 6]])

def test_masking(self):
a = named_array.NamedNumpyArray([[1, 2, 3, 4], [5, 6, 7, 8]],
Expand Down

0 comments on commit e994d8b

Please sign in to comment.