From 365fe04ff8e9f54c24ef3a1383c339c04a31e34d Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Thu, 22 Jan 2026 18:16:45 +0100 Subject: [PATCH 1/2] change direction of velocity fit in template metrics --- .../metrics/template/metrics.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index a1af1de348..fd39cb6a66 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -283,22 +283,21 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y" sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] - -def fit_velocity(peak_times, channel_dist): +def fit_line_robust(x, y): """ - Fit velocity from peak times and channel distances using robust Theilsen estimator. + Fit line using robust Theil-Sen estimator (median of pairwise slopes). """ - # from scipy.stats import linregress - # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) - from sklearn.linear_model import TheilSenRegressor - theil = TheilSenRegressor() - theil.fit(peak_times.reshape(-1, 1), channel_dist) + # Center data to improve numerical stability + X = (x - x.mean()).reshape(-1, 1) + y = y - y.mean() + + theil = TheilSenRegressor(fit_intercept=False) + theil.fit(X, y) slope = theil.coef_[0] - intercept = theil.intercept_ - score = theil.score(peak_times.reshape(-1, 1), channel_dist) - return slope, intercept, score + score = theil.score(X, y) + return slope, score def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs): @@ -354,7 +353,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) channel_locations_above = channel_locations[channels_above] peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) - velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above) + inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above) + velocity_above = 1 / inv_velocity_above if score < min_r2: velocity_above = np.nan @@ -367,7 +367,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) channel_locations_below = channel_locations[channels_below] peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) - velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below) + inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below) + velocity_below = 1 / inv_velocity_below if score < min_r2: velocity_below = np.nan From 362a87b98969a99a2cd585c7383a39b04e60e923 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Jan 2026 16:35:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/metrics/template/metrics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index fd39cb6a66..ba3f4a27b8 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -283,6 +283,7 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y" sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] + def fit_line_robust(x, y): """ Fit line using robust Theil-Sen estimator (median of pairwise slopes).