Skip to content

Commit

Permalink
[sparse] specify operand layouts in cusparse.py
Browse files Browse the repository at this point in the history
Why? This can fix issues when inputs have non-standard layouts

PiperOrigin-RevId: 411110145
  • Loading branch information
Jake VanderPlas authored and jax authors committed Nov 19, 2021
1 parent f08a5a0 commit a93c99d
Showing 1 changed file with 57 additions and 59 deletions.
116 changes: 57 additions & 59 deletions jaxlib/cusparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,28 @@
_ops = xla_client.ops
_Shape = xla_client.Shape

def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
def _validate_csr(c, data, indices, indptr, shape):
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
nnz, = c.get_shape(data).dimensions()
assert c.get_shape(indices).dimensions() == (nnz,)
assert c.get_shape(indptr).element_type() == index_dtype
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
return data_dtype, index_dtype, nnz

def _validate_coo(c, data, row, col, shape):
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
nnz, = c.get_shape(data).dimensions()
assert c.get_shape(row).dimensions() == (nnz,)
assert c.get_shape(col).element_type() == index_dtype
assert c.get_shape(col).dimensions() == (nnz,)
return data_dtype, index_dtype, nnz

def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
nnz = c.get_shape(data).dimensions()[0]

buffer_size, opaque = _cusparse.build_csr_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
Expand All @@ -50,10 +65,9 @@ def csr_todense(c, data, indices, indptr, *, shape):
b"cusparse_csr_todense",
operands=(data, indices, indptr),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
Expand Down Expand Up @@ -98,18 +112,16 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):

def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/vector multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
assert c.get_shape(indptr).element_type() == index_dtype
x_dtype = np.dtype(c.get_shape(x).element_type())
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
nnz, = c.get_shape(data).dimensions()
x_dtype = np.dtype(c.get_shape(x).element_type())
x_shape = c.get_shape(x).dimensions()

if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype

buffer_size, opaque = _cusparse.build_csr_matvec_descriptor(
dtype, x_dtype, compute_dtype, index_dtype,
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows

Expand All @@ -118,11 +130,10 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
b"cusparse_csr_matvec",
operands=(data, indices, indptr, x),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
c.get_shape(x),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
_Shape.array_shape(x_dtype, x_shape, (0,))
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
Expand All @@ -136,20 +147,17 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d

def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
"""CSR from dense matrix."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
assert c.get_shape(indptr).element_type() == index_dtype
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
B_dtype = np.dtype(c.get_shape(B).element_type())
B_shape = c.get_shape(B).dimensions()
rows, cols = shape
_, Ccols = B_shape
nnz, = c.get_shape(data).dimensions()

if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype

buffer_size, opaque = _cusparse.build_csr_matmat_descriptor(
dtype, B_dtype, compute_dtype, index_dtype,
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows

Expand All @@ -158,9 +166,9 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
b"cusparse_csr_matmat",
operands=(data, indices, indptr, B),
operand_shapes_with_layout=(
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
Expand All @@ -175,11 +183,8 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d

def coo_todense(c, data, row, col, *, shape):
"""COO to dense matrix."""
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
assert c.get_shape(row).element_type() == index_dtype
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
nnz = c.get_shape(data).dimensions()[0]

buffer_size, opaque = _cusparse.build_coo_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
Expand All @@ -189,10 +194,9 @@ def coo_todense(c, data, row, col, *, shape):
b"cusparse_coo_todense",
operands=(data, row, col),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
Expand Down Expand Up @@ -236,18 +240,16 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):

def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
"""COO matrix/vector multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
assert c.get_shape(row).element_type() == index_dtype
x_dtype = np.dtype(c.get_shape(x).element_type())
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
nnz, = c.get_shape(data).dimensions()
x_dtype = np.dtype(c.get_shape(x).element_type())
x_shape = c.get_shape(x).dimensions()

if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype

buffer_size, opaque = _cusparse.build_coo_matvec_descriptor(
dtype, x_dtype, compute_dtype, index_dtype,
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows

Expand All @@ -256,11 +258,10 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
b"cusparse_coo_matvec",
operands=(data, row, col, x),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
c.get_shape(x),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(x_dtype, x_shape, (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
Expand All @@ -274,20 +275,17 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No

def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
"""COO from dense matrix."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
assert c.get_shape(row).element_type() == index_dtype
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
B_dtype = np.dtype(c.get_shape(B).element_type())
B_shape = c.get_shape(B).dimensions()
rows, cols = shape
_, Ccols = B_shape
nnz, = c.get_shape(data).dimensions()

if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype

buffer_size, opaque = _cusparse.build_coo_matmat_descriptor(
dtype, B_dtype, compute_dtype, index_dtype,
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows

Expand All @@ -296,9 +294,9 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
b"cusparse_coo_matmat",
operands=(data, row, col, B),
operand_shapes_with_layout=(
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
Expand Down

0 comments on commit a93c99d

Please sign in to comment.