Skip to content

Commit

Permalink
Merge pull request chainer#2114 from pfnet/numpy1.12-slice
Browse files Browse the repository at this point in the history
Change behavior of complete_slice to fit numpy 1.12.0
  • Loading branch information
unnonouno authored Jan 18, 2017
2 parents 64018f7 + 9b57198 commit fd20b95
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
31 changes: 22 additions & 9 deletions cupy/core/internal.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ cpdef vector.vector[Py_ssize_t] infer_unknown_dimension(
return ret


@cython.profile(False)
cpdef int _extract_slice_element(x) except *:
try:
return x.__index__()
except AttributeError:
return int(x)


@cython.profile(False)
cpdef slice complete_slice(slice slc, Py_ssize_t dim):
cpdef Py_ssize_t start=0, stop=0, step=0
Expand All @@ -149,31 +157,36 @@ cpdef slice complete_slice(slice slc, Py_ssize_t dim):
step = 1
else:
try:
step = int(slc.step)
step = _extract_slice_element(slc.step)
except TypeError:
raise IndexError(
'slice.step must be int or None: {}'.format(slc))
raise TypeError(
'slice.step must be int or None or have __index__ method: '
'{}'.format(slc))

if step == 0:
raise ValueError('Slice step must be nonzero.')

start_none = slc.start is None
if not start_none:
try:
start = int(slc.start)
start = _extract_slice_element(slc.start)
except TypeError:
raise IndexError(
'slice.start must be int or None: {}'.format(slc))
raise TypeError(
'slice.start must be int or None or have __index__ method: '
'{}'.format(slc))

if start < 0:
start += dim

stop_none = slc.stop is None
if not stop_none:
try:
stop = int(slc.stop)
stop = _extract_slice_element(slc.stop)
except TypeError:
raise IndexError(
'slice.stop must be int or None: {}'.format(slc))
raise TypeError(
'slice.stop must be int or None or have __index__ method: '
'{}'.format(slc))

if stop < 0:
stop += dim

Expand Down
11 changes: 6 additions & 5 deletions tests/cupy_tests/core_tests/test_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,24 +171,25 @@ def test_complete_slice(self):
slice(*self.expect))


@testing.with_requires('numpy>=1.12')
class TestCompleteSliceError(unittest.TestCase):

def test_invalid_step_value(self):
with self.assertRaises(ValueError):
internal.complete_slice(slice(1, 1, 0), 1)

def test_invalid_step_type(self):
with self.assertRaises(IndexError):
with self.assertRaises(TypeError):
internal.complete_slice(slice(1, 1, (1, 2)), 1)

def test_invalid_start_type(self):
with self.assertRaises(IndexError):
with self.assertRaises(TypeError):
internal.complete_slice(slice((1, 2), 1, 1), 1)
with self.assertRaises(IndexError):
with self.assertRaises(TypeError):
internal.complete_slice(slice((1, 2), 1, -1), 1)

def test_invalid_stop_type(self):
with self.assertRaises(IndexError):
with self.assertRaises(TypeError):
internal.complete_slice(slice((1, 2), 1, 1), 1)
with self.assertRaises(IndexError):
with self.assertRaises(TypeError):
internal.complete_slice(slice((1, 2), 1, -1), 1)
1 change: 1 addition & 0 deletions tests/cupy_tests/core_tests/test_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_getitem(self, xp, dtype):
{'shape': (2, 3, 4), 'transpose': None,
'indexes': (slice(None, None, (0, 0)), )},
)
@testing.with_requires('numpy>=1.12.0')
@testing.gpu
class TestArrayInvalidIndex(unittest.TestCase):

Expand Down

0 comments on commit fd20b95

Please sign in to comment.