diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index a1af1de348..ba3f4a27b8 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -284,21 +284,21 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y" return template[:, sort_indices], channel_locations[sort_indices, :] -def fit_velocity(peak_times, channel_dist): +def fit_line_robust(x, y): """ - Fit velocity from peak times and channel distances using robust Theilsen estimator. + Fit line using robust Theil-Sen estimator (median of pairwise slopes). """ - # from scipy.stats import linregress - # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) - from sklearn.linear_model import TheilSenRegressor - theil = TheilSenRegressor() - theil.fit(peak_times.reshape(-1, 1), channel_dist) + # Center data to improve numerical stability + X = (x - x.mean()).reshape(-1, 1) + y = y - y.mean() + + theil = TheilSenRegressor(fit_intercept=False) + theil.fit(X, y) slope = theil.coef_[0] - intercept = theil.intercept_ - score = theil.score(peak_times.reshape(-1, 1), channel_dist) - return slope, intercept, score + score = theil.score(X, y) + return slope, score def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs): @@ -354,7 +354,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) channel_locations_above = channel_locations[channels_above] peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) - velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above) + inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above) + velocity_above = 1 / inv_velocity_above if score < min_r2: velocity_above = np.nan @@ -367,7 +368,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) channel_locations_below = channel_locations[channels_below] peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) - velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below) + inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below) + velocity_below = 1 / inv_velocity_below if score < min_r2: velocity_below = np.nan