Skip to content

Commit

Permalink
ATOL/RTOL of 1e-13 seem OK; added also comparison vs Numpy/LAPACK
Browse files Browse the repository at this point in the history
  • Loading branch information
enzbus committed Nov 24, 2024
1 parent 23aa2b7 commit 1c2aa1f
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions pyspqr/tests/test_pyspqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,56 @@
ABS_ACCURACY = 1e-13
REL_ACCURACY = 1e-13

# We also compare accuracy vs dense QR Numpy/LAPACK accuracy
ACCURACY_LOSS_DENSE = 4.

class TestSuiteSparseQR(TestCase):
"""Unit tests for pyspqr."""

def _check_fwd_mult(self, A, Q, R, E):
def _check_fwd_mult(self, A, Q, R, E, Adense, Qdense, Rdense):
"""Check forward multiplication."""
x = np.random.randn(A.shape[1])
aprod = A @ x
qrprod = Q @ (R @ (E @ x))

adenseprod = Adense @ x
qrdenseprod = Qdense @ (Rdense @ x)

self.assertLessEqual(
np.linalg.norm(aprod - qrprod),
ACCURACY_LOSS_DENSE * np.linalg.norm(adenseprod - qrdenseprod))

self.assertLessEqual(
np.max(np.abs(aprod - qrprod)),
ACCURACY_LOSS_DENSE * np.max(np.abs(adenseprod - qrdenseprod)))

self.assertTrue(
np.allclose(aprod, qrprod, atol=ABS_ACCURACY, rtol=REL_ACCURACY))

def _check_bwd_mult(self, A, Q, R, E):
def _check_bwd_mult(self, A, Q, R, E, Adense, Qdense, Rdense):
"""Check backward multiplication."""
x = np.random.randn(A.shape[0])
atprod = A.T @ x
qrtprod = E.T @ (R.T @ (Q.T @ x))

atdenseprod = Adense.T @ x
qrtdenseprod = Rdense.T @ (Qdense.T @ x)

self.assertLessEqual(
np.linalg.norm(atprod - qrtprod),
ACCURACY_LOSS_DENSE * np.linalg.norm(atdenseprod - qrtdenseprod))

self.assertLessEqual(
np.max(np.abs(atprod - qrtprod)),
ACCURACY_LOSS_DENSE * np.max(np.abs(atdenseprod - qrtdenseprod)))

self.assertTrue(
np.allclose(atprod, qrtprod, atol=ABS_ACCURACY, rtol=REL_ACCURACY))

def _qr_check(self, A):
"""Base test for given matrix."""
Adense = A.todense().A
Qdense, Rdense = np.linalg.qr(Adense)
for ordering in [
'FIXED',
'NATURAL',
Expand All @@ -59,8 +87,8 @@ def _qr_check(self, A):
]:
with self.subTest(ordering=ordering):
Q, R, E = qr(A, ordering)
self._check_fwd_mult(A, Q, R, E)
self._check_bwd_mult(A, Q, R, E)
self._check_fwd_mult(A, Q, R, E, Adense, Qdense, Rdense)
self._check_bwd_mult(A, Q, R, E, Adense, Qdense, Rdense)

def test_corner(self):
"""Test with some corner cases."""
Expand Down

0 comments on commit 1c2aa1f

Please sign in to comment.