Skip to content

Commit

Permalink
Merge pull request scipy#3272 from pv/fix-setdiag
Browse files Browse the repository at this point in the history
BUG: sparse: fix bugs in dia_matrix.setdiag
  • Loading branch information
rgommers committed Feb 25, 2014
2 parents 047ffca + 4cd67a6 commit c43260b
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 19 deletions.
51 changes: 39 additions & 12 deletions scipy/sparse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,25 +728,52 @@ def diagonal(self):
return self.tocsr().diagonal()

def setdiag(self, values, k=0):
"""Fills the diagonal elements {a_ii} with the values from the
given sequence. If k != 0, fills the off-diagonal elements
{a_{i,i+k}} instead.
"""
Set diagonal or off-diagonal elements of the array.
Parameters
----------
values : array_like
New values of the diagonal elements.
Values may have any length. If the diagonal is longer than values,
then the remaining diagonal entries will not be set. If values if
longer than the diagonal, then the remaining values are ignored.
If a scalar value is given, all of the diagonal is set to it.
k : int, optional
Which off-diagonal to set, corresponding to elements a[i,i+k].
Default: 0 (the main diagonal).
values may have any length. If the diagonal is longer than values,
then the remaining diagonal entries will not be set. If values if
longer than the diagonal, then the remaining values are ignored.
"""
M, N = self.shape
if (k > 0 and k >= N) or (k < 0 and -k >= M):
raise ValueError("k exceeds matrix dimensions")
if k < 0:
max_index = min(M+k, N, len(values))
for i,v in enumerate(values[:max_index]):
self[i - k, i] = v
if np.asarray(values).ndim == 0:
# broadcast
max_index = min(M+k, N)
for i in xrange(max_index):
self[i - k, i] = values
else:
max_index = min(M+k, N, len(values))
if max_index <= 0:
return
for i,v in enumerate(values[:max_index]):
self[i - k, i] = v
else:
max_index = min(M, N-k, len(values))
for i,v in enumerate(values[:max_index]):
self[i, i + k] = v
if np.asarray(values).ndim == 0:
# broadcast
max_index = min(M, N-k)
for i in xrange(max_index):
self[i, i + k] = values
else:
max_index = min(M, N-k, len(values))
if max_index <= 0:
return
for i,v in enumerate(values[:max_index]):
self[i, i + k] = v

def _process_toarray_args(self, order, out):
if out is not None:
Expand Down
32 changes: 28 additions & 4 deletions scipy/sparse/dia.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,37 @@ def setdiag(self, values, k=0):
M, N = self.shape
if k <= -M or k >= N:
raise ValueError('k exceeds matrix dimensions')

values = np.asarray(values)

if values.ndim == 0:
# broadcast
values_n = np.inf
else:
values_n = len(values)

if k < 0:
n = min(M + k, N, values_n)
min_index = 0
max_index = n
else:
n = min(M, N - k, values_n)
min_index = k
max_index = k + n

if values.ndim != 0:
# allow also longer sequences
values = values[:n]

if k in self.offsets:
self.data[self.offsets == k, :] = values
self.data[self.offsets == k, min_index:max_index] = values
else:
self.offsets = np.append(self.offsets, self.offsets.dtype.type(k))
self.data = np.vstack((self.data,
np.empty((1, N), dtype=self.data.dtype)))
self.data[-1, :] = values
m = max(max_index, self.data.shape[1])
data = np.zeros((self.data.shape[0]+1, m), dtype=self.data.dtype)
data[:-1,:self.data.shape[1]] = self.data
data[-1, min_index:max_index] = values
self.data = data

setdiag.__doc__ = _data_matrix.setdiag.__doc__

Expand Down
65 changes: 62 additions & 3 deletions scipy/sparse/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,14 +610,73 @@ def test_diagonal(self):
assert_equal(self.spmatrix(m).diagonal(),diag(m))

def test_setdiag(self):
def dense_setdiag(a, v, k):
v = np.asarray(v)
if k >= 0:
n = min(a.shape[0], a.shape[1] - k)
if v.ndim != 0:
n = min(n, len(v))
v = v[:n]
i = np.arange(0, n)
j = np.arange(k, k + n)
a[i,j] = v
elif k < 0:
dense_setdiag(a.T, v, -k)
return

def check_setdiag(a, b, k):
# Check setting diagonal using a scalar, a vector of
# correct length, and too short or too long vectors
for r in [-1, len(np.diag(a, k)), 2, 30]:
if r < 0:
v = int(np.random.randint(1, 20, size=1))
else:
v = np.random.randint(1, 20, size=r)

dense_setdiag(a, v, k)
b.setdiag(v, k)

# check that dense_setdiag worked
d = np.diag(a, k)
if np.asarray(v).ndim == 0:
assert_array_equal(d, v, err_msg=msg + " %d" % (r,))
else:
n = min(len(d), len(v))
assert_array_equal(d[:n], v[:n], err_msg=msg + " %d" % (r,))
# check that sparse setdiag worked
assert_array_equal(b.A, a, err_msg=msg + " %d" % (r,))

# comprehensive test
np.random.seed(1234)
for dtype in [np.int8, np.float64]:
for m in [0, 1, 3, 10]:
for n in [0, 1, 3, 10]:
for k in range(-m+1, n-1):
msg = repr((dtype, m, n, k))
a = np.zeros((m, n), dtype=dtype)
b = self.spmatrix((m, n), dtype=dtype)

check_setdiag(a, b, k)

# check overwriting etc
for k2 in np.random.randint(-m+1, n-1, size=12):
check_setdiag(a, b, k2)


# simpler test case
m = self.spmatrix(np.eye(3))
values = [3, 2, 1]
# it is out of limits
assert_raises(ValueError, m.setdiag, values, k=4)
m.setdiag(values)
assert_array_equal(m.diagonal(), values)

# test setting offdiagonals (k!=0)
m.setdiag(values, k=1)
assert_array_equal(m.A, np.array([[3, 3, 0],
[0, 2, 2],
[0, 0, 1]]))
m.setdiag(values, k=-2)
assert_array_equal(m.A, np.array([[3, 3, 0],
[0, 2, 2],
[3, 0, 1]]))
m.setdiag((9,), k=2)
assert_array_equal(m.A[0,2], 9)
m.setdiag((9,), k=-2)
Expand Down

0 comments on commit c43260b

Please sign in to comment.