From 4cd67a6914aae56f958253de032cdffe4fd34798 Mon Sep 17 00:00:00 2001 From: Pauli Virtanen Date: Sun, 2 Feb 2014 17:20:16 +0200 Subject: [PATCH] BUG: sparse: fix bugs in dia_matrix.setdiag - hide details of the internal dia_matrix data representation in setdiag(), so that behavior matches the other spmatrix types - fix bugs in resizing self.data, which doesn't always have self.shape[1]==N - add scalar broadcasting behavior to setdiag() for all matrix types --- scipy/sparse/base.py | 51 ++++++++++++++++++++------ scipy/sparse/dia.py | 32 ++++++++++++++-- scipy/sparse/tests/test_base.py | 65 +++++++++++++++++++++++++++++++-- 3 files changed, 129 insertions(+), 19 deletions(-) diff --git a/scipy/sparse/base.py b/scipy/sparse/base.py index e9fb9dbaa0f0..9e00a0547334 100644 --- a/scipy/sparse/base.py +++ b/scipy/sparse/base.py @@ -728,25 +728,52 @@ def diagonal(self): return self.tocsr().diagonal() def setdiag(self, values, k=0): - """Fills the diagonal elements {a_ii} with the values from the - given sequence. If k != 0, fills the off-diagonal elements - {a_{i,i+k}} instead. + """ + Set diagonal or off-diagonal elements of the array. + + Parameters + ---------- + values : array_like + New values of the diagonal elements. + + Values may have any length. If the diagonal is longer than values, + then the remaining diagonal entries will not be set. If values if + longer than the diagonal, then the remaining values are ignored. + + If a scalar value is given, all of the diagonal is set to it. + + k : int, optional + Which off-diagonal to set, corresponding to elements a[i,i+k]. + Default: 0 (the main diagonal). - values may have any length. If the diagonal is longer than values, - then the remaining diagonal entries will not be set. If values if - longer than the diagonal, then the remaining values are ignored. """ M, N = self.shape if (k > 0 and k >= N) or (k < 0 and -k >= M): raise ValueError("k exceeds matrix dimensions") if k < 0: - max_index = min(M+k, N, len(values)) - for i,v in enumerate(values[:max_index]): - self[i - k, i] = v + if np.asarray(values).ndim == 0: + # broadcast + max_index = min(M+k, N) + for i in xrange(max_index): + self[i - k, i] = values + else: + max_index = min(M+k, N, len(values)) + if max_index <= 0: + return + for i,v in enumerate(values[:max_index]): + self[i - k, i] = v else: - max_index = min(M, N-k, len(values)) - for i,v in enumerate(values[:max_index]): - self[i, i + k] = v + if np.asarray(values).ndim == 0: + # broadcast + max_index = min(M, N-k) + for i in xrange(max_index): + self[i, i + k] = values + else: + max_index = min(M, N-k, len(values)) + if max_index <= 0: + return + for i,v in enumerate(values[:max_index]): + self[i, i + k] = v def _process_toarray_args(self, order, out): if out is not None: diff --git a/scipy/sparse/dia.py b/scipy/sparse/dia.py index 5860a0165138..c633219d6281 100644 --- a/scipy/sparse/dia.py +++ b/scipy/sparse/dia.py @@ -188,13 +188,37 @@ def setdiag(self, values, k=0): M, N = self.shape if k <= -M or k >= N: raise ValueError('k exceeds matrix dimensions') + + values = np.asarray(values) + + if values.ndim == 0: + # broadcast + values_n = np.inf + else: + values_n = len(values) + + if k < 0: + n = min(M + k, N, values_n) + min_index = 0 + max_index = n + else: + n = min(M, N - k, values_n) + min_index = k + max_index = k + n + + if values.ndim != 0: + # allow also longer sequences + values = values[:n] + if k in self.offsets: - self.data[self.offsets == k, :] = values + self.data[self.offsets == k, min_index:max_index] = values else: self.offsets = np.append(self.offsets, self.offsets.dtype.type(k)) - self.data = np.vstack((self.data, - np.empty((1, N), dtype=self.data.dtype))) - self.data[-1, :] = values + m = max(max_index, self.data.shape[1]) + data = np.zeros((self.data.shape[0]+1, m), dtype=self.data.dtype) + data[:-1,:self.data.shape[1]] = self.data + data[-1, min_index:max_index] = values + self.data = data setdiag.__doc__ = _data_matrix.setdiag.__doc__ diff --git a/scipy/sparse/tests/test_base.py b/scipy/sparse/tests/test_base.py index 67dc9b5f5b8d..ef52e824ce53 100644 --- a/scipy/sparse/tests/test_base.py +++ b/scipy/sparse/tests/test_base.py @@ -597,14 +597,73 @@ def test_diagonal(self): assert_equal(self.spmatrix(m).diagonal(),diag(m)) def test_setdiag(self): + def dense_setdiag(a, v, k): + v = np.asarray(v) + if k >= 0: + n = min(a.shape[0], a.shape[1] - k) + if v.ndim != 0: + n = min(n, len(v)) + v = v[:n] + i = np.arange(0, n) + j = np.arange(k, k + n) + a[i,j] = v + elif k < 0: + dense_setdiag(a.T, v, -k) + return + + def check_setdiag(a, b, k): + # Check setting diagonal using a scalar, a vector of + # correct length, and too short or too long vectors + for r in [-1, len(np.diag(a, k)), 2, 30]: + if r < 0: + v = int(np.random.randint(1, 20, size=1)) + else: + v = np.random.randint(1, 20, size=r) + + dense_setdiag(a, v, k) + b.setdiag(v, k) + + # check that dense_setdiag worked + d = np.diag(a, k) + if np.asarray(v).ndim == 0: + assert_array_equal(d, v, err_msg=msg + " %d" % (r,)) + else: + n = min(len(d), len(v)) + assert_array_equal(d[:n], v[:n], err_msg=msg + " %d" % (r,)) + # check that sparse setdiag worked + assert_array_equal(b.A, a, err_msg=msg + " %d" % (r,)) + + # comprehensive test + np.random.seed(1234) + for dtype in [np.int8, np.float64]: + for m in [0, 1, 3, 10]: + for n in [0, 1, 3, 10]: + for k in range(-m+1, n-1): + msg = repr((dtype, m, n, k)) + a = np.zeros((m, n), dtype=dtype) + b = self.spmatrix((m, n), dtype=dtype) + + check_setdiag(a, b, k) + + # check overwriting etc + for k2 in np.random.randint(-m+1, n-1, size=12): + check_setdiag(a, b, k2) + + + # simpler test case m = self.spmatrix(np.eye(3)) values = [3, 2, 1] - # it is out of limits assert_raises(ValueError, m.setdiag, values, k=4) m.setdiag(values) assert_array_equal(m.diagonal(), values) - - # test setting offdiagonals (k!=0) + m.setdiag(values, k=1) + assert_array_equal(m.A, np.array([[3, 3, 0], + [0, 2, 2], + [0, 0, 1]])) + m.setdiag(values, k=-2) + assert_array_equal(m.A, np.array([[3, 3, 0], + [0, 2, 2], + [3, 0, 1]])) m.setdiag((9,), k=2) assert_array_equal(m.A[0,2], 9) m.setdiag((9,), k=-2)