Skip to content

Commit 536d175

Browse files
committed
Don't allow shape mismatches for empty arrays
Instead change callers to have empty arrays of the right shape.
1 parent c0e09d5 commit 536d175

File tree

6 files changed

+31
-13
lines changed

6 files changed

+31
-13
lines changed

lib/matplotlib/cbook.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,6 +2244,21 @@ def _reshape_2D(X):
22442244
return X
22452245

22462246

2247+
def ensure_3d(arr):
2248+
"""
2249+
Return a version of arr with ndim==3, with extra dimensions added
2250+
at the end of arr.shape as needed.
2251+
"""
2252+
arr = np.asanyarray(arr)
2253+
if arr.ndim == 1:
2254+
arr = arr[:, None, None]
2255+
elif arr.ndim == 2:
2256+
arr = arr[:, :, None]
2257+
elif arr.ndim > 3 or arr.ndim < 1:
2258+
raise ValueError("cannot convert arr to 3-dimensional")
2259+
return arr
2260+
2261+
22472262
def violin_stats(X, method, points=100):
22482263
'''
22492264
Returns a list of dictionaries of data which can be used to draw a series

lib/matplotlib/collections.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ class Collection(artist.Artist, cm.ScalarMappable):
8080
# _offsets must be a Nx2 array!
8181
_offsets.shape = (0, 2)
8282
_transOffset = transforms.IdentityTransform()
83-
_transforms = []
84-
85-
83+
_transforms = np.empty((0, 3, 3))
8684

8785
def __init__(self,
8886
edgecolors=None,
@@ -1515,7 +1513,7 @@ def __init__(self, widths, heights, angles, units='points', **kwargs):
15151513
self._angles = np.asarray(angles).ravel() * (np.pi / 180.0)
15161514
self._units = units
15171515
self.set_transform(transforms.IdentityTransform())
1518-
self._transforms = []
1516+
self._transforms = np.empty((0, 3, 3))
15191517
self._paths = [mpath.Path.unit_circle()]
15201518

15211519
def _set_transforms(self):

lib/matplotlib/path.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from numpy import ma
2525

2626
from matplotlib import _path
27-
from matplotlib.cbook import simple_linear_interpolation, maxdict
27+
from matplotlib.cbook import simple_linear_interpolation, maxdict, ensure_3d
2828
from matplotlib import rcParams
2929

3030

@@ -988,7 +988,8 @@ def get_path_collection_extents(
988988
if len(paths) == 0:
989989
raise ValueError("No paths provided")
990990
return Bbox.from_extents(*_path.get_path_collection_extents(
991-
master_transform, paths, transforms, offsets, offset_transform))
991+
master_transform, paths, ensure_3d(transforms),
992+
offsets, offset_transform))
992993

993994

994995
def get_paths_extents(paths, transforms=[]):

lib/matplotlib/tests/test_cbook.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,11 @@ def test_step_fails():
376376
np.arange(12))
377377
assert_raises(ValueError, cbook._step_validation,
378378
np.arange(12), np.arange(3))
379+
380+
381+
def test_ensure_3d():
382+
assert_array_equal([[[1]], [[2]], [[3]]],
383+
cbook.ensure_3d([1, 2, 3]))
384+
assert_array_equal([[[1], [2]], [[3], [4]]],
385+
cbook.ensure_3d([[1, 2], [3, 4]]))
386+
assert_raises(ValueError, cbook.ensure_3d, [[[[1]]]])

lib/matplotlib/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from sets import Set as set
4949

5050
from .path import Path
51+
from .cbook import ensure_3d
5152

5253
DEBUG = False
5354
# we need this later, but this is very expensive to set up
@@ -666,7 +667,8 @@ def count_overlaps(self, bboxes):
666667
667668
bboxes is a sequence of :class:`BboxBase` objects
668669
"""
669-
return count_bboxes_overlapping_bbox(self, [np.array(x) for x in bboxes])
670+
return count_bboxes_overlapping_bbox(
671+
self, ensure_3d([np.array(x) for x in bboxes]))
670672

671673
def expanded(self, sw, sh):
672674
"""

src/numpy_cpp.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -448,12 +448,6 @@ class array_view : public detail::array_view_accessors<array_view, T, ND>
448448
return 1;
449449
}
450450
}
451-
if (PyArray_NDIM(tmp) > 0 && PyArray_DIM(tmp, 0) == 0) {
452-
// accept dimension mismatch for empty arrays
453-
Py_XDECREF(m_arr);
454-
m_arr = tmp;
455-
return 1;
456-
}
457451
if (PyArray_NDIM(tmp) != ND) {
458452
PyErr_Format(PyExc_ValueError,
459453
"Expected %d-dimensional array, got %d",

0 commit comments

Comments
 (0)