Skip to content

Commit 07abf9d

Browse files
committed
Added support for a riemannian variant of clean_asr
- also added covariance utilities and is_debug function to streamline targeted testing
1 parent 287e210 commit 07abf9d

File tree

6 files changed

+265
-37
lines changed

6 files changed

+265
-37
lines changed

src/eegprep/clean_asr.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Dict, Any, Optional, Union, Tuple
2+
from typing import Dict, Any, Optional, Union, Tuple, Optional
33
from copy import deepcopy
44

55
import numpy as np
@@ -22,7 +22,7 @@ def clean_asr(
2222
ref_tolerances: Union[Tuple[float, float], str] = (-3.5, 5.5),
2323
ref_wndlen: Union[float, str] = 1.0,
2424
use_gpu: bool = False,
25-
useriemannian: bool = False,
25+
useriemannian: Optional[str] = None,
2626
maxmem: Optional[int] = 64
2727
) -> Dict[str, Any]:
2828
"""Run the Artifact Subspace Reconstruction (ASR) method on EEG data.
@@ -55,7 +55,9 @@ def clean_asr(
5555
for a channel to be considered 'bad' during calibration data selection. Default: (-3.5, 5.5). Use 'off' to disable.
5656
ref_wndlen (Union[float, str], optional): Window length in seconds for calibration data selection granularity. Default: 1.0. Use 'off' to disable.
5757
use_gpu (bool, optional): Whether to try using GPU (requires compatible hardware and libraries, currently ignored). Default: False.
58-
useriemannian (bool, optional): Whether to use Riemannian ASR variant (NOT IMPLEMENTED). Default: False.
58+
useriemannian (str, optional): Option to use a Riemannian ASR variant. Can be set to 'calib' to use a Riemannian estimate
59+
at calibration time; this makes somewhat different statistical tradeoffs than the default, resulting in a somewhat different
60+
baseline rejection threshold; as a result it is suggested to visually check results and adjust the cutoff as needed. Default: None (disabled).
5961
maxmem (Optional[int], optional): Maximum memory in MB (passed to asr_calibrate/process, but chunking based on it is not implemented in Python port). Default: 64.
6062
6163
Returns:
@@ -66,8 +68,6 @@ def clean_asr(
6668
ImportError: If automatic calibration data selection is needed (`ref_maxbadchannels` is float) but `clean_windows` cannot be imported.
6769
ValueError: If input arguments are invalid or calibration fails critically.
6870
"""
69-
if useriemannian:
70-
raise NotImplementedError("The Riemannian ASR variant is not implemented in this Python port.")
7171

7272
if 'data' not in EEG or 'srate' not in EEG or 'nbchan' not in EEG:
7373
raise ValueError("EEG dictionary must contain 'data', 'srate', and 'nbchan'.")
@@ -128,14 +128,14 @@ def clean_asr(
128128
# The Python asr_calibrate uses its own defaults for blocksize, filters, etc.
129129
# We only pass the core parameters specified in the clean_asr call signature.
130130
try:
131-
state = asr_calibrate(ref_section_data, srate, cutoff=cutoff, maxmem=maxmem)
131+
state = asr_calibrate(ref_section_data, srate, cutoff=cutoff, maxmem=maxmem, useriemannian=useriemannian)
132132
except ValueError as e:
133133
# Catch specific errors like not enough calibration data
134134
raise ValueError(f"ASR calibration failed: {e}")
135-
except Exception as e:
136-
# Catch unexpected errors during calibration
137-
logger.exception("An unexpected error occurred during ASR calibration.")
138-
raise RuntimeError(f"ASR calibration failed unexpectedly: {e}")
135+
# except Exception as e:
136+
# # Catch unexpected errors during calibration
137+
# logger.exception("An unexpected error occurred during ASR calibration.")
138+
# raise RuntimeError(f"ASR calibration failed unexpectedly: {e}")
139139

140140
del ref_section_data # Free memory
141141

src/eegprep/utils/asr.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44
import scipy.signal
55
import scipy.linalg
66

7-
from .stats import block_geometric_median, fit_eeg_distribution
7+
from .stats import geometric_median, fit_eeg_distribution
88
from .sigproc import moving_average
9+
from .covariance import cov_mean, cov_shrinkage
10+
911

1012
logger = logging.getLogger(__name__)
1113

1214

1315
def asr_calibrate(X, srate, cutoff=None, blocksize=None, B=None, A=None,
1416
window_len=None, window_overlap=None, max_dropout_fraction=None,
15-
min_clean_fraction=None, maxmem=None):
17+
min_clean_fraction=None, maxmem=None, useriemannian=None):
1618
"""Calibration function for the Artifact Subspace Reconstruction (ASR) method.
1719
1820
State = asr_calibrate(Data, SamplingRate, Cutoff, BlockSize, FilterB, FilterA, WindowLength, WindowOverlap, MaxDropoutFraction, MinCleanFraction, MaxMemory)
@@ -56,6 +58,10 @@ def asr_calibrate(X, srate, cutoff=None, blocksize=None, B=None, A=None,
5658
max_dropout_fraction (float, optional): Maximum fraction (0-1) of windows subject to dropouts. Default: 0.1.
5759
min_clean_fraction (float, optional): Minimum fraction (0-1) of windows that must be clean. Default: 0.25.
5860
maxmem (int, optional): Maximum memory in MB (for very large data/many channels). Default: 64.
61+
useriemannian (str, optional): Option to use a Riemannian ASR variant. Can be set to 'calib' to use a Riemannian estimate
62+
at calibration time; this makes somewhat different statistical tradeoffs than the default, resulting in a potentially
63+
different baseline rejection threshold; as a result it is suggested to visually check results and adjust
64+
the cutoff as needed. Default: None (disabled).
5965
6066
Returns:
6167
dict: State dictionary containing calibration results ('M', 'T') and filter parameters ('B', 'A', 'sos', 'iir_state')
@@ -168,22 +174,26 @@ def asr_calibrate(X, srate, cutoff=None, blocksize=None, B=None, A=None,
168174
# Average the accumulated covariances
169175
U /= blocksize
170176

171-
# Reshape for geometric median calculation
172-
U_reshaped = U.reshape(C * C, -1).T # Shape: (num_blocks, C*C)
173-
174-
# Calculate the geometric median of covariance matrices
175-
logger.info("Calculating robust geometric median covariance...")
176-
med = block_geometric_median(U_reshaped)
177-
178-
# Handle NaN cases (can happen with single observation or degenerate data)
177+
# compute a robust average of the covariance matrices
178+
med = None
179+
if useriemannian in ('calib', 'all', True):
180+
logger.info("Calculating Riemannian geometric median covariance...")
181+
U = U.transpose(2, 0, 1)
182+
# small amount of shrinkage to prevent singularities
183+
U = cov_shrinkage(U, 1e-4, target='scaled-eye')
184+
med = cov_mean(U, robust=True)
185+
if med is None or np.any(np.isnan(med)):
186+
if med is not None:
187+
logger.warning("Riemannian geometric median calculation resulted in "
188+
"NaNs. Using standard geometric median as fallback.")
189+
logger.info("Calculating robust geometric median covariance...")
190+
med = geometric_median(U.reshape(C * C, -1).T)
179191
if np.any(np.isnan(med)):
180-
if U_reshaped.shape[0] == 1:
181-
med = np.median(U_reshaped, axis=0)
182-
else:
183-
logger.warning("Geometric median calculation resulted in NaNs. Using standard median as fallback.")
184-
med = np.median(U_reshaped, axis=0)
192+
logger.warning("Geometric median calculation resulted in NaNs. "
193+
"Using standard median as fallback.")
194+
med = np.median(U, axis=-1)
185195

186-
# Reshape median back to matrix form
196+
# make sure median is reshaped back to matrix form
187197
M_robust = np.reshape(med, (C, C))
188198

189199
# Get the mixing matrix M (matrix square root of the robust covariance)
@@ -269,6 +279,7 @@ def asr_calibrate(X, srate, cutoff=None, blocksize=None, B=None, A=None,
269279
'carry': None, # Initial carry buffer (will be set in process)
270280
'last_R': None, # Initial reconstruction matrix (will be set in process)
271281
'last_trivial': True, # Initial trivial flag
282+
'useriemannian': useriemannian, # Riemannian ASR variant option
272283
}
273284

274285
return state

src/eegprep/utils/covariance.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
"""Tools for working with covariance matrices or stacks thereof."""
2+
3+
# Copyright (c) 2015-2025 Syntrogi Inc. dba Intheon.
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
24+
import logging
25+
26+
import numpy as np
27+
28+
logger = logging.getLogger(__name__)
29+
30+
__all__ = ['cov_mean', 'cov_logm', 'cov_expm', 'cov_powm', 'cov_sqrtm', 'cov_rsqrtm', 'cov_sqrtm2', 'cov_shrinkage']
31+
32+
33+
def diag_nd(M):
    """Build diagonal matrices from the last axis of an array.

    Like np.diag for a vector, but for an input of shape (..., N) this returns
    an array of shape (..., N, N) holding one diagonal matrix per leading index.

    Args:
        M: array of shape (..., N) holding the diagonal entries.

    Returns:
        array of shape (..., N, N) with the same dtype as M; all off-diagonal
        entries are zero.
    """
    *dims, N = M.shape
    # allocate the zero-filled result once and write all diagonals in a single
    # vectorized assignment; unlike a per-matrix loop followed by a
    # concatenation, this also works when the stack is empty (a leading
    # dimension of size 0)
    out = np.zeros((*dims, N, N), dtype=M.dtype)
    idx = np.arange(N)
    out[..., idx, idx] = M
    return out
41+
42+
43+
def cov_logm(C):
    """Calculate the matrix logarithm of a covariance matrix or ...,N,N array.

    The logarithm is applied to the eigenvalues obtained from np.linalg.eigh,
    so eigenvalues that are zero or negative yield -inf or NaN entries.

    Args:
        C: a symmetric matrix (N,N) or stack of matrices (...,N,N).

    Returns:
        the matrix logarithm, with the same shape as C.
    """
    D, V = np.linalg.eigh(C)
    # scaling the columns of V by f(D) is equivalent to V @ diag(f(D)) but
    # avoids materializing a dense stack of diagonal matrices
    return (V * np.log(D)[..., np.newaxis, :]) @ V.swapaxes(-2, -1)
47+
48+
49+
def cov_expm(C):
    """Calculate the matrix exponential of a symmetric matrix or ...,N,N array.

    The exponential is applied to the eigenvalues obtained from
    np.linalg.eigh.

    Args:
        C: a symmetric matrix (N,N) or stack of matrices (...,N,N).

    Returns:
        the matrix exponential, with the same shape as C.
    """
    D, V = np.linalg.eigh(C)
    # scaling the columns of V by f(D) is equivalent to V @ diag(f(D)) but
    # avoids materializing a dense stack of diagonal matrices
    return (V * np.exp(D)[..., np.newaxis, :]) @ V.swapaxes(-2, -1)
53+
54+
55+
def cov_powm(C, exp):
    """Calculate a matrix power of a covariance matrix or ...,N,N array.

    The power is applied to the eigenvalues obtained from np.linalg.eigh.

    Args:
        C: a symmetric matrix (N,N) or stack of matrices (...,N,N).
        exp: the exponent that the eigenvalues are raised to.

    Returns:
        the matrix power, with the same shape as C.
    """
    D, V = np.linalg.eigh(C)
    # scaling the columns of V by f(D) is equivalent to V @ diag(f(D)) but
    # avoids materializing a dense stack of diagonal matrices
    return (V * (D**exp)[..., np.newaxis, :]) @ V.swapaxes(-2, -1)
59+
60+
61+
def cov_sqrtm(C):
    """Calculate the matrix square root of a covariance matrix or ...,N,N array.

    The square root is applied to the eigenvalues obtained from
    np.linalg.eigh, so negative eigenvalues yield NaN entries.

    Args:
        C: a symmetric matrix (N,N) or stack of matrices (...,N,N).

    Returns:
        the matrix square root, with the same shape as C.
    """
    D, V = np.linalg.eigh(C)
    # scaling the columns of V by f(D) is equivalent to V @ diag(f(D)) but
    # avoids materializing a dense stack of diagonal matrices
    return (V * np.sqrt(D)[..., np.newaxis, :]) @ V.swapaxes(-2, -1)
65+
66+
67+
def cov_rsqrtm(C):
    """Calculate the matrix reciprocal square root of a covariance matrix or ...,N,N array.

    The reciprocal square root is applied to the eigenvalues obtained from
    np.linalg.eigh, so eigenvalues that are zero or negative yield inf or NaN
    entries.

    Args:
        C: a symmetric matrix (N,N) or stack of matrices (...,N,N).

    Returns:
        the reciprocal matrix square root, with the same shape as C.
    """
    D, V = np.linalg.eigh(C)
    # scaling the columns of V by f(D) is equivalent to V @ diag(f(D)) but
    # avoids materializing a dense stack of diagonal matrices
    return (V * (1./np.sqrt(D))[..., np.newaxis, :]) @ V.swapaxes(-2, -1)
71+
72+
73+
def cov_sqrtm2(C):
    """Calculate the matrix square root, and its reciprocal, for a covariance matrix or ...,N,N array.

    Both results are derived from a single eigendecomposition
    (np.linalg.eigh), which is cheaper than computing them separately.

    Args:
        C: a symmetric matrix (N,N) or stack of matrices (...,N,N).

    Returns:
        a tuple (sqrt, rsqrt) of the matrix square root and the reciprocal
        matrix square root, each with the same shape as C.
    """
    D, V = np.linalg.eigh(C)
    sqrtD = np.sqrt(D)
    Vt = V.swapaxes(-2, -1)
    # scaling the columns of V by f(D) is equivalent to V @ diag(f(D)) but
    # avoids materializing dense stacks of diagonal matrices
    root = (V * sqrtD[..., np.newaxis, :]) @ Vt
    inv_root = (V * (1./sqrtD)[..., np.newaxis, :]) @ Vt
    return root, inv_root
78+
79+
80+
def cov_mean(X, *, weights=None, robust=False, iters=50, tol=1e-5, huber=0,
             nancheck=False, verbose=False):
    """Calculate the (weighted) average of a set of covariance matrices on the
    manifold of SPD matrices, optionally robustly using the geometric median or
    Huber mean.

    The estimate is initialized with the weighted arithmetic mean and then
    refined by an iterative update with a decaying step size.

    Args:
        X: a M,N,N array of covariance matrices
        weights: optionally a vector of sample weights (can be unnormalized)
        robust: whether to use a robust estimator
        iters: maximum number of iterations
        huber: huber threshold (delta parameter); only used if robust=True.
          Can be set to
          * None: use regular least-squares solution
          * 0: use geometric / l1 median
          * >0: use a Huber mean with the given value as the threshold
        tol: tolerance for convergence check
        nancheck: check for NaNs; if enabled, a RuntimeError is raised as soon
          as the estimate contains NaNs
        verbose: generate verbose output (will print deviations in huber=None mode)

    Returns:
        the N,N mean covariance matrix

    Raises:
        RuntimeError: if nancheck is enabled and NaNs occurred in the estimate.
    """
    # This algorithm is based on:
    # [1] Ostresh et al., 1978, "On the Convergence of a Class of Iterative Methods for Solving the Weber Location Problem"
    # [2] Fletcher et al., 2004, "Principal Geodesic Analysis on Symmetric Spaces: Statistics of Diffusion Tensors"
    # [3] Fletcher et al. 2010, "The geometric median on Riemannian manifolds with application to robust atlas estimation"
    # [4] Barachant et al., 2014, "Multiclass Brain-Computer Interface Classification by Riemannian Geometry"
    weights = np.ones(len(X)) if weights is None else np.asarray(weights)
    scales = weights

    # initialize with the weighted arithmetic mean
    mu = np.sum(X * weights[:, None, None], axis=0)/np.sum(weights)
    # step size and divergence check threshold
    step, thresh = 1.0, 1e20
    for _ in range(iters):
        mu_sqrt, mu_rsqrt = cov_sqrtm2(mu)
        # linearize around mu (this would be the tangent space, but we omit
        # the pre/post-multiplied mu_sqrt terms since they cancel in both
        # the scale calculation and the exponential map)
        Xt = cov_logm(mu_rsqrt @ X @ mu_rsqrt)
        # geometric-median correction (downweight each pt by its riemannian
        # distance from mu, which we calc here after linearization)
        if robust:
            # deviations/errors per sample
            d = np.sqrt(np.sum(np.square(Xt), axis=(-2, -1)))
            # a sample that coincides with the current estimate has zero
            # deviation; clamp to machine epsilon so that the reciprocal
            # weights below stay finite instead of turning the update into NaN
            d_safe = np.maximum(d, np.finfo(d.dtype).eps)
            # apply robust scale factor to provided sample weights
            if huber is None:
                scales = weights
                if verbose:
                    logger.info(f"median deviations: {np.median(d)}")
            elif huber == 0:
                scales = weights / d_safe
            else:
                w = np.where(d <= huber, 1, huber / d_safe)
                scales = weights * w
        # get update Jacobian (explicitly renormalized by the sum of scales)
        J = np.sum(Xt * scales[:, None, None], axis=0)/np.sum(scales)
        # apply update on manifold
        mu = mu_sqrt @ cov_expm(step * J) @ mu_sqrt
        # fail fast: once mu contains NaNs, further iterations cannot recover
        if nancheck and np.any(np.isnan(mu)):
            raise RuntimeError("NaNs occurred in cov_mean()")
        # convergence checks
        Jnorm = np.sqrt(np.sum(np.square(J)))
        if Jnorm < tol or step < tol:
            break
        h = step * Jnorm
        if h < thresh:
            # exponentially decaying learning rate
            step *= 0.95
            thresh = h
        else:
            # prevent blow-up
            step /= 2
    return mu
153+
154+
155+
def cov_shrinkage(cov, shrinkage=0, *, target='eye'):
    """Regularize the given covariance matrix or stack of matrices using shrinkage.

    The result is the convex combination
    shrinkage * target + (1 - shrinkage) * cov.

    Args:
        cov: the covariance matrix (N,N) or stack of matrices (...,N,N).
        shrinkage: degree of shrinkage, between 0 and 1; a falsy value (e.g.,
          0) returns cov unchanged without inspecting the target.
        target: target matrix to shrink towards; can be:
          'eye': the identity matrix (classic shrinkage; good for small values
            of shrinkage)
          'scaled-eye': the identity matrix, scaled to the average variance
            of the data (can be practical when shrinkage degree is large, since
            otherwise whitening will not have unit variance)
          'diag': the diagonal of the covariance matrix (diagonal shrinkage)

    Returns:
        the regularized covariance matrix or stack of matrices.

    Raises:
        ValueError: if shrinkage is nonzero and target is not one of the
          supported values.
    """
    if not shrinkage:
        return cov  # early exit

    N = cov.shape[-1]

    # all targets are expressed so that they broadcast against cov in the
    # final blend, which avoids allocating a full stack of identity matrices
    if target == 'eye':
        shrink_target = np.eye(N, dtype=cov.dtype)
    elif target == 'scaled-eye':
        # average variance (trace / N) for each matrix in the stack; this is
        # a scalar if cov is 2D, hence the asarray before adding the axes
        scale = np.asarray(np.trace(cov, axis1=-2, axis2=-1) / N)
        shrink_target = np.eye(N, dtype=cov.dtype) * scale[..., np.newaxis, np.newaxis]
    elif target == 'diag':
        # keep only the main diagonal of each matrix in the stack
        shrink_target = np.where(np.eye(N, dtype=bool), cov, 0)
    else:
        raise ValueError(f'Unsupported shrinkage target: {target}')

    return shrinkage * shrink_target + (1 - shrinkage) * cov

src/eegprep/utils/spatial.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55

66
# Helper function (vectorized version of MATLAB's interpMx)
7-
# Using a leading underscore as is common for internal helper functions in Python
87
def _interpMx(cosEE, order, tol):
98
"""
109
Compute the interpolation matrix for a set of point pairs (vectorized).

src/eegprep/utils/testing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Testing utilities."""
22

3-
import numpy as np
3+
import sys
44
import unittest
55

6-
__all__ = ['compare_eeg', 'DebuggableTestCase']
6+
import numpy as np
7+
8+
__all__ = ['compare_eeg', 'DebuggableTestCase', 'is_debug']
79

810

911
# default to True since the round-tripping through file can force data to
@@ -43,3 +45,7 @@ def debugTestCase(cls):
4345
loader = unittest.defaultTestLoader
4446
testSuite = loader.loadTestsFromTestCase(cls)
4547
testSuite.debug()
48+
49+
def is_debug():
    """Determine whether Python is running in debug mode.

    Returns:
        bool: True if a trace function is currently installed via
        sys.settrace (as debuggers typically do), False otherwise, including
        on interpreters that do not provide sys.gettrace at all.
    """
    # look up gettrace defensively; calling the getattr() default (None)
    # directly would raise a TypeError on interpreters lacking sys.gettrace
    gettrace = getattr(sys, 'gettrace', None)
    return gettrace is not None and gettrace() is not None

0 commit comments

Comments
 (0)