|
| 1 | +"""Tools for working with covariance matrices or stacks thereof.""" |
| 2 | + |
| 3 | +# Copyright (c) 2015-2025 Syntrogi Inc. dba Intheon. |
| 4 | + |
| 5 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 6 | +# of this software and associated documentation files (the "Software"), to deal |
| 7 | +# in the Software without restriction, including without limitation the rights |
| 8 | +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 9 | +# copies of the Software, and to permit persons to whom the Software is |
| 10 | +# furnished to do so, subject to the following conditions: |
| 11 | + |
| 12 | +# The above copyright notice and this permission notice shall be included in all |
| 13 | +# copies or substantial portions of the Software. |
| 14 | + |
| 15 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 16 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 17 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 18 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 19 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 20 | +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 21 | +# SOFTWARE. |
| 22 | + |
| 23 | + |
| 24 | +import logging |
| 25 | + |
| 26 | +import numpy as np |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | +__all__ = ['cov_mean', 'cov_logm', 'cov_expm', 'cov_powm', 'cov_sqrtm', 'cov_rsqrtm', 'cov_sqrtm2', 'cov_shrinkage'] |
| 31 | + |
| 32 | + |
| 33 | +def diag_nd(M): |
| 34 | + """Like np.diag, but in case of a ...,N, returns a ...,N,N array of diag matrices.""" |
| 35 | + *dims, N = M.shape |
| 36 | + if dims: |
| 37 | + cat = np.concatenate([np.diag(d) for d in M.reshape((-1, N))]) |
| 38 | + return np.reshape(cat, dims + [N, N]) |
| 39 | + else: |
| 40 | + return np.diag(M) |
| 41 | + |
| 42 | + |
| 43 | +def cov_logm(C): |
| 44 | + """Calculate the matrix logarithm of a covariance matrix or ...,N,N array.""" |
| 45 | + D, V = np.linalg.eigh(C) |
| 46 | + return V @ diag_nd(np.log(D)) @ V.swapaxes(-2, -1) |
| 47 | + |
| 48 | + |
| 49 | +def cov_expm(C): |
| 50 | + """Calculate the matrix exponent of a covariance matrix or ...,N,N array.""" |
| 51 | + D, V = np.linalg.eigh(C) |
| 52 | + return V @ diag_nd(np.exp(D)) @ V.swapaxes(-2, -1) |
| 53 | + |
| 54 | + |
| 55 | +def cov_powm(C, exp): |
| 56 | + """Calculate a matrix power of a covariance matrix or ...,N,N array.""" |
| 57 | + D, V = np.linalg.eigh(C) |
| 58 | + return V @ diag_nd(D**exp) @ V.swapaxes(-2, -1) |
| 59 | + |
| 60 | + |
| 61 | +def cov_sqrtm(C): |
| 62 | + """Calculate the matrix square root of a covariance matrix or ...,N,N array.""" |
| 63 | + D, V = np.linalg.eigh(C) |
| 64 | + return V @ diag_nd(np.sqrt(D)) @ V.swapaxes(-2, -1) |
| 65 | + |
| 66 | + |
| 67 | +def cov_rsqrtm(C): |
| 68 | + """Calculate the matrix reciprocal square root of a covariance matrix or ...,N,N array.""" |
| 69 | + D, V = np.linalg.eigh(C) |
| 70 | + return V @ diag_nd(1./np.sqrt(D)) @ V.swapaxes(-2, -1) |
| 71 | + |
| 72 | + |
| 73 | +def cov_sqrtm2(C): |
| 74 | + """Calculate the matrix square root, and its reciprocal, for a covariance matrix or ...,N,N array.""" |
| 75 | + D, V = np.linalg.eigh(C) |
| 76 | + sqrtD = np.sqrt(D) |
| 77 | + return V @ diag_nd(sqrtD) @ V.swapaxes(-2, -1), V @ diag_nd(1./sqrtD) @ V.swapaxes(-2, -1) |
| 78 | + |
| 79 | + |
| 80 | +def cov_mean(X, *, weights=None, robust=False, iters=50, tol=1e-5, huber=0, |
| 81 | + nancheck=False, verbose=False): |
| 82 | + """Calculate the (weighted) average of a set of covariance matrices on the |
| 83 | + manifold of SPD matrices, optionally robustly using the geometric median or |
| 84 | + Huber mean. |
| 85 | +
|
| 86 | + Args: |
| 87 | + X: a M,N,N array of covariance matrices |
| 88 | + weights: optionally a vector of sample weights (can be unnormalized) |
| 89 | + robust: whether to use a robust estimator |
| 90 | + iters: maximum number of iterations |
| 91 | + huber: huber threshold (delta parameter); can be set to |
| 92 | + * None: use regular least-squares solution |
| 93 | + * 0: use geometric / l1 median |
| 94 | + * >0: use a Huber mean with the given value as the threshold |
| 95 | + tol: tolerance for convergence check |
| 96 | + nancheck: check for NaNs |
| 97 | + verbose: generate verbose output (will print deviations in huber=None mode) |
| 98 | +
|
| 99 | + Returns: |
| 100 | + the N,N mean covariance matrix |
| 101 | + """ |
| 102 | + # This algorithm is based on: |
| 103 | + # [1] Ostresh et al., 1978, "On the Convergence of a Class of Iterative Methods for Solving the Weber Location Problem" |
| 104 | + # [2] Fletcher et al., 2004, "Principal Geodesic Analysis on Symmetric Spaces: Statistics of Diffusion Tensors" |
| 105 | + # [3] Fletcher et al. 2010, "The geometric median on Riemannian manifolds with application to robust atlas estimation" |
| 106 | + # [4] Barachant et al., 2014, "Multiclass Brain-Computer Interface Classification by Riemannian Geometry" |
| 107 | + weights = np.ones(len(X)) if weights is None else np.asarray(weights) |
| 108 | + scales = weights |
| 109 | + |
| 110 | + mu = np.sum(X * weights[:, None, None], axis=0)/np.sum(weights) |
| 111 | + # step size and divergence check threshold |
| 112 | + step, thresh = 1.0, 1e20 |
| 113 | + for i in range(iters): |
| 114 | + mu_sqrt, mu_rsqrt = cov_sqrtm2(mu) |
| 115 | + # linearize around mu (this would be the tangent space, but we omit |
| 116 | + # the pre/post-multiplied mu_sqrt terms since they cancel in both |
| 117 | + # the scale calculation and the exponential map) |
| 118 | + Xt = cov_logm(mu_rsqrt @ X @ mu_rsqrt) |
| 119 | + # geometric-median correction (downweight each pt by its riemannian |
| 120 | + # distance from mu, which we calc here after linearization) |
| 121 | + if robust: |
| 122 | + # deviations/errors per sample |
| 123 | + d = np.sqrt(np.sum(np.square(Xt), axis=(-2, -1))) |
| 124 | + # apply robust scale factor to provided sample weights |
| 125 | + if huber is None: |
| 126 | + scales = weights |
| 127 | + if verbose: |
| 128 | + logger.info(f"median deviations: {np.median(d)}") |
| 129 | + elif huber == 0: |
| 130 | + scales = weights / d |
| 131 | + else: |
| 132 | + w = np.where(d <= huber, 1, huber / d) |
| 133 | + scales = weights * w |
| 134 | + # get update Jacobian (np.average takes care of renormalization) |
| 135 | + J = np.sum(Xt * scales[:, None, None], axis=0)/np.sum(scales) |
| 136 | + # apply update on manifold |
| 137 | + mu = mu_sqrt @ cov_expm(step * J) @ mu_sqrt |
| 138 | + # convergence checks |
| 139 | + Jnorm = np.sqrt(np.sum(np.square(J))) |
| 140 | + if Jnorm < tol or step < tol: |
| 141 | + break |
| 142 | + h = step * Jnorm |
| 143 | + if h < thresh: |
| 144 | + # exponentially decaying learning rate |
| 145 | + step *= 0.95 |
| 146 | + thresh = h |
| 147 | + else: |
| 148 | + # prevent blow-up |
| 149 | + step /= 2 |
| 150 | + if nancheck and np.any(np.isnan(mu)): |
| 151 | + raise RuntimeError("NaNs occurred in cov_mean()") |
| 152 | + return mu |
| 153 | + |
| 154 | + |
| 155 | +def cov_shrinkage(cov, shrinkage=0, *, target='eye'): |
| 156 | + """Regularize the given covariance matrix or stack of matrices using shrinkage. |
| 157 | +
|
| 158 | + Args: |
| 159 | + cov: the covariance matrix (N,N) or stack of matrices (...,N,N). |
| 160 | + shrinkage: degree of shrinkage, between 0 and 1 |
| 161 | + target: target matrix to shrink towards; can be: |
| 162 | + 'eye': the identity matrix (classic shrinkage; good for small values |
| 163 | + of shrinkage) |
| 164 | + 'scaled-eye': the identity matrix, scaled to the average variance |
| 165 | + of the data (can be practical when shrinkage degree is large, since |
| 166 | + otherwise whitening will not have unit variance) |
| 167 | + 'diag': the diagonal of the covariance matrix (diagonal shrinkage) |
| 168 | +
|
| 169 | + Returns: |
| 170 | + the regularized covariance matrix or stack of matrices. |
| 171 | + """ |
| 172 | + if not shrinkage: |
| 173 | + return cov # early exit |
| 174 | + |
| 175 | + N = cov.shape[-1] |
| 176 | + |
| 177 | + if target == 'eye': |
| 178 | + # create a stack of identity matrices matching cov's shape |
| 179 | + eye_target = np.zeros_like(cov) |
| 180 | + eye_target[..., range(N), range(N)] = 1 |
| 181 | + elif target == 'scaled-eye': |
| 182 | + # calculate trace for each matrix in the stack (or single matrix) |
| 183 | + # trace_cov will have shape cov.shape[:-2] or be scalar if cov is 2D |
| 184 | + trace_cov = np.trace(cov, axis1=-2, axis2=-1) |
| 185 | + scale = trace_cov / N |
| 186 | + |
| 187 | + # create a base stack of identity matrices |
| 188 | + eye_base = np.zeros_like(cov) |
| 189 | + eye_base[..., range(N), range(N)] = 1 |
| 190 | + |
| 191 | + # apply scaling |
| 192 | + scale_val = scale |
| 193 | + if cov.ndim > 2: |
| 194 | + scale_val = scale[..., np.newaxis, np.newaxis] |
| 195 | + eye_target = eye_base * scale_val |
| 196 | + elif target == 'diag': |
| 197 | + # get the main diagonal of each matrix in the stack |
| 198 | + main_diagonals = np.diagonal(cov, axis1=-2, axis2=-1) |
| 199 | + # create a stack of diagonal matrices |
| 200 | + eye_target = diag_nd(main_diagonals) |
| 201 | + else: |
| 202 | + raise ValueError(f'Unsupported shrinkage target: {target}') |
| 203 | + |
| 204 | + cov_regu = shrinkage * eye_target + (1 - shrinkage) * cov |
| 205 | + return cov_regu |
0 commit comments