Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/spikeinterface/metrics/template/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,21 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y"
return template[:, sort_indices], channel_locations[sort_indices, :]


def fit_velocity(peak_times, channel_dist):
def fit_line_robust(x, y):
"""
Fit velocity from peak times and channel distances using robust Theilsen estimator.
Fit line using robust Theil-Sen estimator (median of pairwise slopes).
"""
# from scipy.stats import linregress
# slope, intercept, _, _, _ = linregress(peak_times, channel_dist)

from sklearn.linear_model import TheilSenRegressor

theil = TheilSenRegressor()
theil.fit(peak_times.reshape(-1, 1), channel_dist)
# Center data to improve numerical stability
X = (x - x.mean()).reshape(-1, 1)
y = y - y.mean()

theil = TheilSenRegressor(fit_intercept=False)
theil.fit(X, y)
slope = theil.coef_[0]
intercept = theil.intercept_
score = theil.score(peak_times.reshape(-1, 1), channel_dist)
return slope, intercept, score
score = theil.score(X, y)
return slope, score


def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs):
Expand Down Expand Up @@ -354,7 +354,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
channel_locations_above = channel_locations[channels_above]
peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time
distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above])
velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above)
inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above)
velocity_above = 1 / inv_velocity_above
if score < min_r2:
velocity_above = np.nan

Expand All @@ -367,7 +368,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
channel_locations_below = channel_locations[channels_below]
peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time
distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below])
velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below)
inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below)
velocity_below = 1 / inv_velocity_below
if score < min_r2:
velocity_below = np.nan

Expand Down