From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 01/40] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 7279b6753f30aff0bfe485b8ee884e56b3068822 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 7 Jan 2026 16:26:47 +0100 Subject: [PATCH 02/40] wip --- .../core/analyzer_extension_core.py | 34 +++++++- src/spikeinterface/core/basesorting.py | 20 +++++ src/spikeinterface/core/node_pipeline.py | 11 ++- src/spikeinterface/core/sorting_tools.py | 77 +++++++++++++++++ .../core/tests/test_basesorting.py | 64 ++++++++++++-- .../metrics/quality/misc_metrics.py | 85 +++++++++++++++---- .../tests/test_interpolate_bad_channels.py | 2 +- 7 files changed, 268 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 74ef52e258..5e46f20d22 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -19,6 +19,7 @@ from .template import Templates from .sorting_tools import random_spikes_selection from .job_tools import 
fix_job_kwargs, split_job_kwargs +from .node_pipeline import base_period_dtype class ComputeRandomSpikes(AnalyzerExtension): @@ -1331,6 +1332,21 @@ class BaseSpikeVectorExtension(AnalyzerExtension): need_backward_compatibility_on_load = False nodepipeline_variables = [] # to be defined in subclass + def __init__(self, sorting_analyzer): + super().__init__(sorting_analyzer) + self._segment_slices = None + + @property + def segment_slices(self): + if self._segment_slices is None: + segment_slices = [] + spikes = self.sorting_analyzer.sorting.to_spike_vector() + for segment_index in range(self.sorting_analyzer.get_num_segments()): + i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) + segment_slices.append(slice(i0, i1)) + self._segment_slices = segment_slices + return self._segment_slices + def _set_params(self, **kwargs): params = kwargs.copy() return params @@ -1369,7 +1385,7 @@ def _run(self, verbose=False, **job_kwargs): for d, name in zip(data, data_names): self.data[name] = d - def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True): + def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods=None, copy=True): """ Return extension data. If the extension computes more than one `nodepipeline_variables`, the `return_data_name` is used to specify which one to return. @@ -1383,13 +1399,15 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, return_data_name : str | None, default: None The name of the data to return. If None and multiple `nodepipeline_variables` are computed, the first one is returned. + periods : array of unit_period dtype, default: None + Optional periods (segment_index, start_sample_index, end_sample_index, unit_index) to slice output data copy : bool, default: True Whether to return a copy of the data (only for outputs="numpy") Returns ------- numpy.ndarray | dict - The + The requested data in numpy or by unit format. 
""" from spikeinterface.core.sorting_tools import spike_vector_to_indices @@ -1404,6 +1422,18 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] + if periods is not None: + # TODO: slice this properly with unit_indices + required = np.dtype(base_period_dtype).names + if not required.issubset(periods.dtype.names): + raise ValueError(f"Period must have the following fields: {required}") + # slice data according to period + segment_slices = self.segment_slices + all_data_segment = all_data[segment_slices[periods["segment_index"]]] + start = periods["start_sample_index"] + end = periods["end_sample_index"] + all_data = all_data_segment[start:end] + if outputs == "numpy": if copy: return all_data.copy() # return a copy to avoid modification diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 98159fb646..b6440f8e2b 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -626,6 +626,26 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def select_periods(self, periods): + """ + Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + periods : numpy.array of unit_period_dtype + Period (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + BaseSorting + A new sorting object with only samples between start_sample_index and end_sample_index + for the given segment_index. 
+ """ + from spikeinterface.core.sorting_tools import select_sorting_periods + + return select_sorting_periods(self, periods) + def split_by(self, property="group", outputs="dict"): """ Splits object based on a certain property (e.g. "group") diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 71654a67b4..f6bf3cb31f 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -22,11 +22,20 @@ ("segment_index", "int64"), ] - spike_peak_dtype = base_peak_dtype + [ ("unit_index", "int64"), ] +base_period_dtype = [ + ("start_sample_index", "int64"), + ("end_sample_index", "int64"), + ("segment_index", "int64"), +] + +unit_period_dtype = base_period_dtype + [ + ("unit_index", "int64"), +] + class PipelineNode: diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..9a9a3670ef 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -228,6 +228,83 @@ def random_spikes_selection( return random_spikes_indices +def select_sorting_periods_mask(sorting: BaseSorting, periods): + """ + Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + sorting : BaseSorting + The sorting object. + periods : numpy.array of unit_period_dtype + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + numpy.array + A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods. 
+ """ + spike_vector = sorting.to_spike_vector() + spike_vector_list = sorting.to_spike_vector(concatenated=False) + keep_mask = np.zeros(len(spike_vector), dtype=bool) + all_global_indices = spike_vector_to_indices(spike_vector_list, unit_ids=sorting.unit_ids, absolute_index=True) + for segment_index in range(sorting.get_num_segments()): + global_indices_segment = all_global_indices[segment_index] + # filter periods by segment + periods_in_segment = periods[periods["segment_index"] == segment_index] + for unit_index, unit_id in enumerate(sorting.unit_ids): + # filter by unit index + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + global_indices = global_indices_segment[unit_id] + spiketrains = spike_vector[global_indices]["sample_index"] + if len(periods_for_unit) > 0: + for period in periods_for_unit: + mask = (spiketrains >= period["start_sample_index"]) & (spiketrains < period["end_sample_index"]) + keep_mask[global_indices[mask]] = True + return keep_mask + + +def select_sorting_periods(sorting: BaseSorting, periods): + """ + Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + S + periods : numpy.array of unit_period_dtype + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + BaseSorting + A new sorting object with only samples between start_sample_index and end_sample_index + for the given segment_index. 
+ """ + from spikeinterface.core.numpyextractors import NumpySorting + from spikeinterface.core.node_pipeline import unit_period_dtype + + if periods is not None: + if not isinstance(periods, np.ndarray): + periods = np.array([periods], dtype=unit_period_dtype) + required = set(np.dtype(unit_period_dtype).names) + if not required.issubset(periods.dtype.names): + raise ValueError(f"Period must have the following fields: {required}") + + spike_vector = sorting.to_spike_vector() + keep_mask = select_sorting_periods_mask(sorting, periods) + sliced_spike_vector = spike_vector[keep_mask] + + sorting = NumpySorting( + sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids + ) + sorting.copy_metadata(sorting) + return sorting + else: + return sorting + + ### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 54befd40ec..ada35a57e9 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -3,9 +3,7 @@ but check only for BaseRecording general methods. 
""" -import shutil -from pathlib import Path - +import time import numpy as np import pytest from numpy.testing import assert_raises @@ -17,15 +15,15 @@ SharedMemorySorting, NpzFolderSorting, NumpyFolderSorting, + generate_ground_truth_recording, + generate_sorting, create_sorting_npz, generate_sorting, load, ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal -from spikeinterface.core.generate import generate_sorting - -from spikeinterface.core import generate_recording, generate_ground_truth_recording +from spikeinterface.core.node_pipeline import unit_period_dtype def test_BaseSorting(create_cache_folder): @@ -226,7 +224,61 @@ def test_time_slice(): ) +def test_select_periods(): + sampling_frequency = 10_000.0 + duration = 1_000 + num_samples = int(sampling_frequency * duration) + num_units = 1000 + sorting = generate_sorting( + durations=[duration, duration], sampling_frequency=sampling_frequency, num_units=num_units + ) + + rng = np.random.default_rng() + + # number of random periods + n_periods = 10_000 + # generate random periods + segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods) + start_samples = rng.integers(0, num_samples, n_periods) + durations = rng.integers(100, 100_000, n_periods) + end_samples = start_samples + durations + valid_periods = end_samples < num_samples + segment_indices = segment_indices[valid_periods] + start_samples = start_samples[valid_periods] + end_samples = end_samples[valid_periods] + unit_index = rng.integers(0, num_units - 1, len(segment_indices)) + + periods = np.zeros(len(segment_indices), dtype=unit_period_dtype) + periods["segment_index"] = segment_indices + periods["start_sample_index"] = start_samples + periods["end_sample_index"] = end_samples + periods["unit_index"] = unit_index + + t_start = time.perf_counter() + sliced_sorting = sorting.select_periods(periods=periods) + t_stop = time.perf_counter() + elapsed 
= t_stop - t_start + print(f"select_periods took {elapsed:.2f} seconds for {len(periods)} periods") + + # Check that all spikes in the sliced sorting are within the periods + for segment_index in range(sorting.get_num_segments()): + for unit_index, unit_id in enumerate(sorting.unit_ids): + spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + spikes_in_periods = np.array([], dtype=spiketrain.dtype) + periods_in_segment = periods[periods["segment_index"] == segment_index] + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + for period in periods_for_unit: + start_sample = period["start_sample_index"] + end_sample = period["end_sample_index"] + spikes_in_period = spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)] + spikes_in_periods = np.concatenate((spikes_in_periods, spikes_in_period)) + if not len(spikes_in_periods) == len(spiketrain_sliced): + print(f"Mismatch in number of spikes!: {len(spikes_in_periods)} vs {len(spiketrain_sliced)}") + + if __name__ == "__main__": test_BaseSorting() test_npy_sorting() test_empty_sorting() + test_select_periods() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index c6b07da52e..028b2eeca5 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -19,12 +19,13 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.postprocessing import correlogram_for_one_segment -from spikeinterface.core import SortingAnalyzer, get_noise_levels +from spikeinterface.core import SortingAnalyzer, get_noise_levels, select_segment_sorting from spikeinterface.core.template_tools import ( get_template_extremum_channel, 
get_template_extremum_amplitude, get_dense_templates_array, ) +from spikeinterface.core.node_pipeline import base_period_dtype from ..spiketrain.metrics import NumSpikes, FiringRate @@ -35,7 +36,9 @@ HAVE_NUMBA = False -def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): +def compute_presence_ratios( + sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None +): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -51,6 +54,9 @@ mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -63,6 +69,7 @@ To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -182,7 +189,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): """ Calculate Inter-Spike Interval (ISI) violations. @@ -204,6 +211,9 @@ Minimum possible inter-spike interval, in ms. 
This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -235,6 +245,7 @@ res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -280,7 +291,7 @@ class ISIViolation(BaseMetric): def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None ): """ Calculate the number of refractory period violations. @@ -300,6 +311,9 @@ censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -332,6 +346,8 @@ return None sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) num_segments = sorting_analyzer.get_num_segments() @@ -392,6 +408,7 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, + periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -417,6 +434,9 @@ Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -431,6 +451,8 @@ """ duration = sorting_analyzer.get_total_duration() sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -486,7 +508,7 @@ class SlidingRPViolation(BaseMetric): } -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -504,6 +526,9 @@ ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. References ---------- @@ -520,6 +545,7 @@ res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -556,7 +582,7 @@ class Synchrony(BaseMetric): } -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -571,6 +597,9 @@ The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -584,6 +613,8 @@ sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + if unit_ids is None: unit_ids = sorting.unit_ids @@ -635,6 +666,7 @@ def compute_amplitude_cv_metrics( percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", + periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. 
@@ -658,6 +690,8 @@ the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -683,7 +717,7 @@ if unit_ids is None: unit_ids = sorting.unit_ids - amps = sorting_analyzer.get_extension(amplitude_extension).get_data() + amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice segment_slices = [] @@ -752,6 +786,7 @@ def compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, + periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -770,6 +805,9 @@ The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -805,7 +843,7 @@ invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -837,7 +875,7 @@ class AmplitudeCutoff(BaseMetric): depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): """ Compute median of the amplitude distributions (in absolute value). @@ -847,6 +885,9 @@ A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -865,7 +906,7 @@ all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -882,7 +923,9 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): +def compute_noise_cutoffs( + sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None +): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -906,6 +949,9 @@ Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -934,7 +980,7 @@ invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -972,6 +1018,7 @@ def compute_drift_metrics( min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, + periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1006,6 +1053,9 @@ min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). 
@@ -1032,8 +1082,7 @@ unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data() - # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") + spike_locations = spike_locations_ext.get_data(periods=periods) spikes = sorting.to_spike_vector() spike_locations_by_unit = {} for unit_id in unit_ids: @@ -1145,12 +1194,14 @@ class Drift(BaseMetric): depend_on = ["spike_locations"] +# TODO def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, + periods=None, **kwargs, ): """ @@ -1173,6 +1224,9 @@ correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. 
@@ -1189,6 +1243,7 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1201,7 +1256,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(period=period) if not HAVE_NUMBA: warnings.warn( diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index ab7ae9e7b5..75e41620f4 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -130,7 +130,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan recording._properties["contact_vector"][idx][1] = x[idx] # generate random bad channel locations - bad_channel_indexes = rng.choice(num_channels, rng.randint(1, int(num_channels / 5)), replace=False) + bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) bad_channel_ids = recording.channel_ids[bad_channel_indexes] # Run SI and IBL interpolation and check against eachother From 1962f212f2dcd68a56275d123a554b7757558143 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 7 Jan 2026 17:22:37 +0100 Subject: [PATCH 03/40] Fix test for base sorting and propagate to basevector extension --- .../core/analyzer_extension_core.py | 18 ++++++------------ .../core/tests/test_basesorting.py | 19 ++++++++++++------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 
5e46f20d22..9b93807b8c 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -17,9 +17,8 @@ from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels from .template import Templates -from .sorting_tools import random_spikes_selection +from .sorting_tools import random_spikes_selection, select_sorting_periods_mask from .job_tools import fix_job_kwargs, split_job_kwargs -from .node_pipeline import base_period_dtype class ComputeRandomSpikes(AnalyzerExtension): @@ -1423,16 +1422,11 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, all_data = self.data[return_data_name] if periods is not None: - # TODO: slice this properly with unit_indices - required = np.dtype(base_period_dtype).names - if not required.issubset(periods.dtype.names): - raise ValueError(f"Period must have the following fields: {required}") - # slice data according to period - segment_slices = self.segment_slices - all_data_segment = all_data[segment_slices[periods["segment_index"]]] - start = periods["start_sample_index"] - end = periods["end_sample_index"] - all_data = all_data_segment[start:end] + keep_mask = select_sorting_periods_mask( + self.sorting_analyzer.sorting, + periods, + ) + all_data = all_data[keep_mask] if outputs == "numpy": if copy: diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index ada35a57e9..18f632ed34 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -253,6 +253,7 @@ def test_select_periods(): periods["start_sample_index"] = start_samples periods["end_sample_index"] = end_samples periods["unit_index"] = unit_index + periods = np.sort(periods, order=["segment_index", "start_sample_index"]) t_start = time.perf_counter() sliced_sorting = 
sorting.select_periods(periods=periods) @@ -262,19 +263,23 @@ def test_select_periods(): # Check that all spikes in the sliced sorting are within the periods for segment_index in range(sorting.get_num_segments()): + periods_in_segment = periods[periods["segment_index"] == segment_index] for unit_index, unit_id in enumerate(sorting.unit_ids): spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) - spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) - spikes_in_periods = np.array([], dtype=spiketrain.dtype) - periods_in_segment = periods[periods["segment_index"] == segment_index] + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + spiketrain_in_periods = [] for period in periods_for_unit: start_sample = period["start_sample_index"] end_sample = period["end_sample_index"] - spikes_in_period = spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)] - spikes_in_periods = np.concatenate((spikes_in_periods, spikes_in_period)) - if not len(spikes_in_periods) == len(spiketrain_sliced): - print(f"Mismatch in number of spikes!: {len(spikes_in_periods)} vs {len(spiketrain_sliced)}") + spiketrain_in_periods.append(spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)]) + if len(spiketrain_in_periods) == 0: + spiketrain_in_periods = np.array([], dtype=spiketrain.dtype) + else: + spiketrain_in_periods = np.unique(np.concatenate(spiketrain_in_periods)) + + spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + assert len(spiketrain_in_periods) == len(spiketrain_sliced) if __name__ == "__main__": From 528c82b7951db9d509030b7fc10e3796fb69347b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 8 Jan 2026 08:33:58 +0100 Subject: [PATCH 04/40] Fix tests in quailty metrics --- .../metrics/quality/misc_metrics.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) 
diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 028b2eeca5..4a7ef04554 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -69,7 +69,7 @@ def compute_presence_ratios( To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -245,7 +245,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(sorting, periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -346,7 +346,7 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) @@ -451,7 +451,7 @@ def compute_sliding_rp_violations( """ duration = sorting_analyzer.get_total_duration() sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids @@ -545,7 +545,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = 
sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -613,7 +613,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -717,7 +717,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - amps = sorting_analyzer.get_extension(amplitude_extension).get_data(period=period) + amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice segment_slices = [] @@ -843,7 +843,7 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, period=period) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -906,7 +906,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, period=period) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -980,7 +980,7 @@ def compute_noise_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, period=period) + amplitudes_by_units = 
extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -1082,7 +1082,7 @@ def compute_drift_metrics( unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data(period=period) + spike_locations = spike_locations_ext.get_data(periods=periods) spikes = sorting.to_spike_vector() spike_locations_by_unit = {} for unit_id in unit_ids: @@ -1243,7 +1243,7 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1256,7 +1256,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(period=period) + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods) if not HAVE_NUMBA: warnings.warn( From 775dda710adc4c9b4a7eddb7e5dea99d8d9df884 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 8 Jan 2026 12:46:23 +0100 Subject: [PATCH 05/40] Fix retrieval of spikevector features --- src/spikeinterface/core/analyzer_extension_core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 9b93807b8c..804418a2ff 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1421,6 +1421,7 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] + keep_mask = None if periods is 
not None: keep_mask = select_sorting_periods_mask( self.sorting_analyzer.sorting, @@ -1436,6 +1437,8 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + if keep_mask is not None: + spike_vector = spike_vector[keep_mask] spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) data_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): From bb46f27ad9f719bfcd0db25fae1e55e4c2cbbe8d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 11:59:57 +0100 Subject: [PATCH 06/40] Update src/spikeinterface/core/sorting_tools.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 9a9a3670ef..4695f9b289 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -265,7 +265,7 @@ def select_sorting_periods_mask(sorting: BaseSorting, periods): return keep_mask -def select_sorting_periods(sorting: BaseSorting, periods): +def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: """ Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. 
From 121a0b19c3c435fa3a3f7bd64508eb440b371393 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 12:00:58 +0100 Subject: [PATCH 07/40] Apply suggestion from @chrishalcrow Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/core/sorting_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 4695f9b289..75e25115ae 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -271,7 +271,6 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: Parameters ---------- - S periods : numpy.array of unit_period_dtype Periods (segment_index, start_sample_index, end_sample_index, unit_index) on which to restrict the sorting. From cbf3213a4c3769eae38f5203a42025531e80f0dd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 13:06:43 +0100 Subject: [PATCH 08/40] refactor presence ratio and drift metrics to use periods properly --- .../metrics/quality/misc_metrics.py | 166 ++++++++++-------- .../quality/tests/test_metrics_functions.py | 5 +- src/spikeinterface/metrics/quality/utils.py | 47 ----- .../metrics/spiketrain/metrics.py | 10 +- src/spikeinterface/metrics/utils.py | 121 +++++++++++++ 5 files changed, 220 insertions(+), 129 deletions(-) delete mode 100644 src/spikeinterface/metrics/quality/utils.py create mode 100644 src/spikeinterface/metrics/utils.py diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4a7ef04554..e720477ee6 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -19,15 +19,14 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.postprocessing import 
correlogram_for_one_segment -from spikeinterface.core import SortingAnalyzer, get_noise_levels, select_segment_sorting +from spikeinterface.core import SortingAnalyzer, get_noise_levels from spikeinterface.core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) -from spikeinterface.core.node_pipeline import base_period_dtype - -from ..spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.utils import compute_bin_edges_per_unit, compute_total_durations_per_unit numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -74,12 +73,16 @@ def compute_presence_ratios( unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] - total_length = sorting_analyzer.get_total_samples() - total_duration = sorting_analyzer.get_total_duration() + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_samples = np.sum(segment_samples) bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - num_bin_edges = total_length // bin_duration_samples + 1 - bin_edges = np.arange(num_bin_edges) * bin_duration_samples + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_duration_s, + ) mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: @@ -90,7 +93,7 @@ def compute_presence_ratios( warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_length < bin_duration_samples: + if total_samples < bin_duration_samples: warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording 
duration. " f"Presence ratios are set to NaN." ) @@ -98,9 +101,15 @@ def compute_presence_ratios( else: for unit_id in unit_ids: spike_train = [] + bin_edges = bin_edges_per_unit[unit_id] + if len(bin_edges) < 2: + presence_ratios[unit_id] = 0.0 + continue + total_duration = total_durations[unit_id] + for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(seg_lengths[:segment_index]) + st = st + np.sum(segment_samples[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -109,7 +118,6 @@ def compute_presence_ratios( presence_ratios[unit_id] = presence_ratio( spike_train, - total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -250,7 +258,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - total_duration_s = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -271,7 +279,8 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 if not any([len(train) > 0 for train in spike_train_list]): continue - ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) + total_duration = total_durations[unit_id] + ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -449,7 +458,7 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) sorting = 
sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) @@ -477,6 +486,7 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue + duration = total_durations[unit_id] contamination[unit_id] = slidingRP_violations( spike_train_list, fs, @@ -582,6 +592,7 @@ class Synchrony(BaseMetric): } +# TODO: refactor for periods def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution @@ -659,6 +670,7 @@ class FiringRange(BaseMetric): } +# TODO: refactor for periods def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, @@ -710,13 +722,14 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = sorting_analyzer.sorting - total_duration = sorting_analyzer.get_total_duration() - spikes = sorting.to_spike_vector() - num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") if unit_ids is None: unit_ids = sorting.unit_ids + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + spikes = sorting.to_spike_vector() + num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice @@ -729,6 +742,7 @@ def compute_amplitude_cv_metrics( all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency @@ -1078,34 +1092,30 @@ def compute_drift_metrics( 
check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") spike_locations = spike_locations_ext.get_data(periods=periods) - spikes = sorting.to_spike_vector() - spike_locations_by_unit = {} - for unit_id in unit_ids: - unit_index = sorting.id_to_index(unit_id) - # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code - spike_mask = spikes["unit_index"] == unit_index - spike_locations_by_unit[unit_id] = spike_locations[spike_mask] + spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - total_duration = sorting_analyzer.get_total_duration() - if total_duration < min_num_bins * interval_s: - warnings.warn( - "The recording is too short given the specified 'interval_s' and " - "'min_num_bins'. Drift metrics will be set to NaN" - ) - empty_dict = {unit_id: np.nan for unit_id in unit_ids} - if return_positions: - return res(empty_dict, empty_dict, empty_dict), np.nan - else: - return res(empty_dict, empty_dict, empty_dict) + # total_duration = sorting_analyzer.get_total_duration() + # if total_duration < min_num_bins * interval_s: + # warnings.warn( + # "The recording is too short given the specified 'interval_s' and " + # "'min_num_bins'. 
Drift metrics will be set to NaN" + # ) + # empty_dict = {unit_id: np.nan for unit_id in unit_ids} + # if return_positions: + # return res(empty_dict, empty_dict, empty_dict), np.nan + # else: + # return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1113,45 +1123,50 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) + reference_positions = {} + for unit_id in unit_ids: + reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments median_position_segments = None - for segment_index in range(sorting_analyzer.get_num_segments()): - seg_length = sorting_analyzer.get_num_samples(segment_index) - num_bin_edges = seg_length // interval_samples + 1 - bins = np.arange(num_bin_edges) * interval_samples - spike_vector = sorting.to_spike_vector() - - # retrieve spikes in segment - i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) - spikes_in_segment = spike_vector[i0:i1] - spike_locations_in_segment = spike_locations[i0:i1] - - # compute median positions (if less than min_spikes_per_interval, median position is 0) - median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) - for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) - spikes_in_bin = spikes_in_segment[i0:i1] - spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] - - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - mask = spikes_in_bin["unit_index"] == unit_ind - if np.sum(mask) >= min_spikes_per_interval: - median_positions[i, 
bin_index] = np.median(spike_locations_in_bin[mask]) - if median_position_segments is None: - median_position_segments = median_positions - else: - median_position_segments = np.hstack((median_position_segments, median_positions)) + spike_vector = sorting.to_spike_vector() + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=interval_s, + ) + + median_positions_per_unit = {} + for i, unit in enumerate(unit_ids): + bins = bin_edges_for_units[unit] + num_bins = len(bins) - 1 + if num_bins < min_num_bins: + warnings.warn( + f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " + f"'min_num_bins'. Drift metrics will be set to NaN" + ) + drift_ptps[unit] = np.nan + drift_stds[unit] = np.nan + drift_mads[unit] = np.nan + continue + + bin_spike_indices = np.searchsorted(spike_vector["sample_index"], bins) + median_positions = np.nan * np.zeros(num_bins) + for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): + spikes_in_bin = spike_vector[i0:i1] + spike_locations_in_bin = spike_locations[i0:i1][direction] - # finally, compute deviations and drifts - position_diffs = median_position_segments - reference_positions[:, None] - for i, unit_id in enumerate(unit_ids): - position_diff = position_diffs[i] + unit_index = sorting_analyzer.sorting.id_to_index(unit) + mask = spikes_in_bin["unit_index"] == unit_index + if np.sum(mask) >= min_spikes_per_interval: + median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) + else: + median_positions[bin_index] = np.nan + median_positions_per_unit[unit] = median_positions + + # now compute deviations and drifts for this unit + position_diff = median_positions - reference_positions[unit_id] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): @@ -1169,8 +1184,9 @@ 
def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions + outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit else: outs = res(drift_ptps, drift_stds, drift_mads) return outs @@ -1385,7 +1401,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1393,8 +1409,6 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None ---------- spike_train : np.ndarray Spike times for this unit, in samples. - total_length : int - Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). 
num_bin_edges : int, optional diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c0dd6c6033..57516d6bc3 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,11 +12,8 @@ synthesize_random_firings, ) -from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions -# from spikeinterface.metrics.quality_metric_list import ( -# _misc_metric_name_to_func, -# ) from spikeinterface.metrics.quality import ( get_quality_metric_list, diff --git a/src/spikeinterface/metrics/quality/utils.py b/src/spikeinterface/metrics/quality/utils.py deleted file mode 100644 index 844a7da7f5..0000000000 --- a/src/spikeinterface/metrics/quality/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. - - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. - total_points : array-like - Number of points in each unit distribution. - - Returns - ------- - all_pcs : numpy.ndarray - PC scores for each point. - all_labels : numpy.array - Labels for each point. 
- """ - from scipy.stats import multivariate_normal - - np.random.seed(0) - - if len(np.array(center_locations).shape) == 1: - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations, total_points) - ] - all_pcs = np.concatenate(distributions, axis=0) - - else: - all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) - for channel in range(center_locations.shape[0]): - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations[channel], total_points) - ] - all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) - - all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) - - return all_pcs, all_labels diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index ba66d0671c..0ddb5fabe7 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -2,7 +2,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric -def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): +def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): """ Compute the number of spike across segments. @@ -12,6 +12,8 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the number of spikes. If None, all units are used. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -20,6 +22,7 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids num_segs = sorting.get_num_segments() @@ -43,7 +46,7 @@ class NumSpikes(BaseMetric): metric_columns = {"num_spikes": int} -def compute_firing_rates(sorting_analyzer, unit_ids=None): +def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): """ Compute the firing rate across segments. @@ -53,6 +56,8 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the firing rate. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -61,6 +66,7 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids total_duration = sorting_analyzer.get_total_duration() diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py new file mode 100644 index 0000000000..beb9b505ff --- /dev/null +++ b/src/spikeinterface/metrics/utils.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import numpy as np + + +def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): + """ + Compute bin edges for units, optionally taking into account periods. + + Parameters + ---------- + sorting : Sorting + Sorting object containing unit information. + segment_samples : list or array-like + Number of samples in each segment. 
+    bin_duration_s : float, default: 1.0
+        Duration of each bin in seconds
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit
+    """
+    bin_edges_for_units = {}
+    num_segments = len(segment_samples)
+    bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency)
+
+    if periods is not None:
+        for unit_id in sorting.unit_ids:
+            unit_index = sorting.id_to_index(unit_id)
+            periods_unit = periods[periods["unit_index"] == unit_index]
+            bin_edges = []
+            for seg_index in range(num_segments):
+                seg_periods = periods_unit[periods_unit["segment_index"] == seg_index]
+                if len(seg_periods) == 0:
+                    continue
+                seg_start = np.sum(segment_samples[:seg_index])
+                for period in seg_periods:
+                    start_sample = seg_start + period["start_sample_index"]
+                    end_sample = seg_start + period["end_sample_index"]
+                    bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples))
+            bin_edges_for_units[unit_id] = np.array(bin_edges)
+    else:
+        total_length = np.sum(segment_samples)
+        for unit_id in sorting.unit_ids:
+            bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples)
+    return bin_edges_for_units
+
+
+def compute_total_durations_per_unit(sorting_analyzer, periods=None):
+    """
+    Compute total duration for each unit, optionally taking into account periods.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object.
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit.
+
+    Returns
+    -------
+    dict
+        Total duration for each unit.
+ """ + if periods is not None: + total_durations = {} + sorting = sorting_analyzer.sorting + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + total_duration = 0 + for period in periods_unit: + total_duration += period["end_sample_index"] - period["start_sample_index"] + total_durations[unit_id] = total_duration / sorting.sampling_frequency + else: + total_durations = { + unit_id: sorting_analyzer.get_total_duration_per_unit() for unit_id in sorting_analyzer.unit_ids + } + return total_durations + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. + all_labels : numpy.array + Labels for each point. 
+ """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels From 4409aa5fc1dcd1d5fb93d3a075a598df0c18113f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 13:16:15 +0100 Subject: [PATCH 09/40] Fix rp_violations --- .../metrics/quality/misc_metrics.py | 11 +++-- src/spikeinterface/metrics/utils.py | 41 ++++++++++++++----- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index e720477ee6..74aca85dce 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -26,7 +26,11 @@ get_dense_templates_array, ) from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate -from spikeinterface.metrics.utils import compute_bin_edges_per_unit, compute_total_durations_per_unit +from spikeinterface.metrics.utils import ( + compute_bin_edges_per_unit, + compute_total_durations_per_unit, + compute_total_samples_per_unit, +) numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -366,7 +370,7 @@ def compute_refrac_period_violations( if unit_ids is None: unit_ids = 
sorting_analyzer.unit_ids - num_spikes = compute_num_spikes(sorting_analyzer) + num_spikes = sorting.count_num_spikes_per_unit() t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -377,7 +381,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = sorting_analyzer.get_total_samples() + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) nb_violations = {} rp_contamination = {} @@ -388,6 +392,7 @@ def compute_refrac_period_violations( nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] N = num_spikes[unit_id] + T = total_samples[unit_id] if N == 0: rp_contamination[unit_id] = np.nan else: diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index beb9b505ff..446f9ce471 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -44,9 +44,9 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per return bin_edges_for_units -def compute_total_durations_per_unit(sorting_analyzer, periods=None): +def get_total_samples_per_unit(sorting_analyzer, periods=None): """ - Compute total duration for each unit, optionally taking into account periods. + Get total number of samples for each unit, optionally taking into account periods. Parameters ---------- @@ -58,22 +58,43 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): Returns ------- dict - Total duration for each unit. + Total number of samples for each unit. 
""" if periods is not None: - total_durations = {} + total_samples = {} sorting = sorting_analyzer.sorting for unit_id in sorting.unit_ids: unit_index = sorting.id_to_index(unit_id) periods_unit = periods[periods["unit_index"] == unit_index] - total_duration = 0 + num_samples_in_period = 0 for period in periods_unit: - total_duration += period["end_sample_index"] - period["start_sample_index"] - total_durations[unit_id] = total_duration / sorting.sampling_frequency + num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] + total_samples[unit_id] = num_samples_in_period else: - total_durations = { - unit_id: sorting_analyzer.get_total_duration_per_unit() for unit_id in sorting_analyzer.unit_ids - } + total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} + return total_samples + + +def compute_total_durations_per_unit(sorting_analyzer, periods=None): + """ + Compute total duration for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total duration for each unit. 
+ """ + total_samples = get_total_samples_per_unit(sorting_analyzer, periods=periods) + total_durations = { + unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() + } return total_durations From 71f8668c5f8eb25d50c53a614f81552947ab1ede Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 13:27:25 +0100 Subject: [PATCH 10/40] implement firing range and fix drift --- .../metrics/quality/misc_metrics.py | 81 ++++++++----------- 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 74aca85dce..b30ab068fe 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -597,7 +597,6 @@ class Synchrony(BaseMetric): } -# TODO: refactor for periods def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution @@ -630,6 +629,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) + segment_samples = [ + sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) + ] if unit_ids is None: unit_ids = sorting.unit_ids @@ -645,15 +647,25 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - for segment_index in range(sorting_analyzer.get_num_segments()): - num_samples = sorting_analyzer.get_num_samples(segment_index) - edges = np.arange(0, num_samples + 1, 
bin_size_samples) + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_size_s, + ) + for unit_id in unit_ids: + bin_edges = bin_edges_per_unit[unit_id] - for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - spike_counts, _ = np.histogram(spike_times, bins=edges) - firing_rates = spike_counts / bin_size_s - firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + # we can concatenate spike trains across segments adding the cumulative number of samples + # as offset, since bin edges are already cumulative + for segment_index in range(sorting_analyzer.get_num_segments()): + st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + st = st + np.sum(segment_samples[:segment_index]) + spike_train.append(st) + spike_train = np.concatenate(spike_train) + + spike_counts, _ = np.histogram(spike_train, bins=bin_edges) + firing_rate_histograms[unit_id] = spike_counts / bin_size_s # finally we compute the percentiles firing_ranges = {} @@ -731,9 +743,9 @@ def compute_amplitude_cv_metrics( unit_ids = sorting.unit_ids sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) spikes = sorting.to_spike_vector() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) @@ -1106,21 +1118,9 @@ def compute_drift_metrics( spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] - interval_samples = int(interval_s * 
sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - # total_duration = sorting_analyzer.get_total_duration() - # if total_duration < min_num_bins * interval_s: - # warnings.warn( - # "The recording is too short given the specified 'interval_s' and " - # "'min_num_bins'. Drift metrics will be set to NaN" - # ) - # empty_dict = {unit_id: np.nan for unit_id in unit_ids} - # if return_positions: - # return res(empty_dict, empty_dict, empty_dict), np.nan - # else: - # return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1133,8 +1133,14 @@ def compute_drift_metrics( reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments - median_position_segments = None spike_vector = sorting.to_spike_vector() + spike_sample_indices = spike_vector["sample_index"] + # we need to add the cumulative sum of segment samples to have global sample indices + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) + for segment_index in range(sorting_analyzer.get_num_segments()): + seg_mask = spike_vector["segment_index"] == segment_index + spike_sample_indices[seg_mask] += cumulative_segment_samples[segment_index] + bin_edges_for_units = compute_bin_edges_per_unit( sorting, segment_samples=segment_samples, @@ -1143,7 +1149,7 @@ def compute_drift_metrics( ) median_positions_per_unit = {} - for i, unit in enumerate(unit_ids): + for unit in unit_ids: bins = bin_edges_for_units[unit] num_bins = len(bins) - 1 if num_bins < min_num_bins: @@ -1156,7 +1162,9 @@ def compute_drift_metrics( drift_mads[unit] = np.nan continue - bin_spike_indices = np.searchsorted(spike_vector["sample_index"], bins) + # bin_edges are global across segments, so we have to use spike_sample_indices, + # since we offseted them to be global + bin_spike_indices = 
np.searchsorted(spike_sample_indices, bins) median_positions = np.nan * np.zeros(num_bins) for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): spikes_in_bin = spike_vector[i0:i1] @@ -1783,29 +1791,6 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): - # used by compute_amplitude_cutoffs and compute_amplitude_medians - - if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: - return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) - - elif sorting_analyzer.has_extension("waveforms"): - amplitudes_by_units = {} - waveforms_ext = sorting_analyzer.get_extension("waveforms") - before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) - for unit_id in unit_ids: - waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - chan_id = extremum_channels_ids[unit_id] - if sorting_analyzer.is_sparse(): - chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] - amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - - return amplitudes_by_units - - if HAVE_NUMBA: import numba From 1ea0d68074a0925d08811d31a8951728346e4c03 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 14:24:19 +0100 Subject: [PATCH 11/40] fix naming issue --- src/spikeinterface/metrics/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 446f9ce471..16058e521f 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -44,7 +44,7 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per return 
bin_edges_for_units -def get_total_samples_per_unit(sorting_analyzer, periods=None): +def compute_total_samples_per_unit(sorting_analyzer, periods=None): """ Get total number of samples for each unit, optionally taking into account periods. @@ -91,7 +91,7 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): dict Total duration for each unit. """ - total_samples = get_total_samples_per_unit(sorting_analyzer, periods=periods) + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) total_durations = { unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() } From a86c2d36c6ddd823e65f999a1f76f9a8607f0938 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 14:31:11 +0100 Subject: [PATCH 12/40] remove solved todos --- src/spikeinterface/metrics/quality/misc_metrics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index b30ab068fe..004f0ee56c 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -687,7 +687,6 @@ class FiringRange(BaseMetric): } -# TODO: refactor for periods def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, @@ -1223,7 +1222,6 @@ class Drift(BaseMetric): depend_on = ["spike_locations"] -# TODO def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, From 3f93f97618930203aadb724dda7a50a53e57b4b6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:42:16 +0100 Subject: [PATCH 13/40] Implement select_segment_periods in core --- .../core/analyzer_extension_core.py | 12 - .../metrics/quality/misc_metrics.py | 301 +++++++----------- .../quality/tests/test_metrics_functions.py | 5 +- src/spikeinterface/metrics/quality/utils.py | 47 +++ .../metrics/spiketrain/metrics.py | 10 +- 5 files changed, 175 insertions(+), 200 
deletions(-) create mode 100644 src/spikeinterface/metrics/quality/utils.py diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 804418a2ff..7ac037a8cd 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1333,18 +1333,6 @@ class BaseSpikeVectorExtension(AnalyzerExtension): def __init__(self, sorting_analyzer): super().__init__(sorting_analyzer) - self._segment_slices = None - - @property - def segment_slices(self): - if self._segment_slices is None: - segment_slices = [] - spikes = self.sorting_analyzer.sorting.to_spike_vector() - for segment_index in range(self.sorting_analyzer.get_num_segments()): - i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) - segment_slices.append(slice(i0, i1)) - self._segment_slices = segment_slices - return self._segment_slices def _set_params(self, **kwargs): params = kwargs.copy() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 004f0ee56c..c6b07da52e 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -25,12 +25,8 @@ get_template_extremum_amplitude, get_dense_templates_array, ) -from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate -from spikeinterface.metrics.utils import ( - compute_bin_edges_per_unit, - compute_total_durations_per_unit, - compute_total_samples_per_unit, -) + +from ..spiketrain.metrics import NumSpikes, FiringRate numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -39,9 +35,7 @@ HAVE_NUMBA = False -def compute_presence_ratios( - sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None -): +def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): """ Calculate 
the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -57,9 +51,6 @@ def compute_presence_ratios( mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -72,21 +63,16 @@ def compute_presence_ratios( To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) - total_samples = np.sum(segment_samples) + seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_length = sorting_analyzer.get_total_samples() + total_duration = sorting_analyzer.get_total_duration() bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - bin_edges_per_unit = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=bin_duration_s, - ) + num_bin_edges = total_length // bin_duration_samples + 1 + bin_edges = np.arange(num_bin_edges) * bin_duration_samples mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: @@ -97,7 +83,7 @@ def compute_presence_ratios( warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_samples < bin_duration_samples: + if total_length < bin_duration_samples: 
warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording duration. " f"Presence ratios are set to NaN." ) @@ -105,15 +91,9 @@ def compute_presence_ratios( else: for unit_id in unit_ids: spike_train = [] - bin_edges = bin_edges_per_unit[unit_id] - if len(bin_edges) < 2: - presence_ratios[unit_id] = 0.0 - continue - total_duration = total_durations[unit_id] - for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(segment_samples[:segment_index]) + st = st + np.sum(seg_lengths[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -122,6 +102,7 @@ def compute_presence_ratios( presence_ratios[unit_id] = presence_ratio( spike_train, + total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -201,7 +182,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -223,9 +204,6 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -257,12 +235,11 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_duration_s = sorting_analyzer.get_total_duration() fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -283,8 +260,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 if not any([len(train) > 0 for train in spike_train_list]): continue - total_duration = total_durations[unit_id] - ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) + ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -304,7 +280,7 @@ class ISIViolation(BaseMetric): def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 ): """ Calculate the number of refractory period violations. @@ -324,9 +300,6 @@ def compute_refrac_period_violations( censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -359,8 +332,6 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) num_segments = sorting_analyzer.get_num_segments() @@ -370,7 +341,7 @@ def compute_refrac_period_violations( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - num_spikes = sorting.count_num_spikes_per_unit() + num_spikes = compute_num_spikes(sorting_analyzer) t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -381,7 +352,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) + T = sorting_analyzer.get_total_samples() nb_violations = {} rp_contamination = {} @@ -392,7 +363,6 @@ def compute_refrac_period_violations( nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] N = num_spikes[unit_id] - T = total_samples[unit_id] if N == 0: rp_contamination[unit_id] = np.nan else: @@ -422,7 +392,6 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, - periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -448,9 +417,6 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -463,10 +429,8 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + duration = sorting_analyzer.get_total_duration() sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -491,7 +455,6 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue - duration = total_durations[unit_id] contamination[unit_id] = slidingRP_violations( spike_train_list, fs, @@ -523,7 +486,7 @@ class SlidingRPViolation(BaseMetric): } -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -541,9 +504,6 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
References ---------- @@ -560,7 +520,6 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -597,7 +556,7 @@ class Synchrony(BaseMetric): } -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -612,9 +571,6 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -628,11 +584,6 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - segment_samples = [ - sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) - ] - if unit_ids is None: unit_ids = sorting.unit_ids @@ -647,25 +598,15 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - bin_edges_per_unit = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=bin_size_s, - ) - for unit_id in unit_ids: - bin_edges = bin_edges_per_unit[unit_id] - - # we can concatenate spike trains across segments adding the cumulative number of samples - # as offset, since bin edges are already cumulative - for segment_index in range(sorting_analyzer.get_num_segments()): - st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(segment_samples[:segment_index]) - spike_train.append(st) - spike_train = np.concatenate(spike_train) + for segment_index in range(sorting_analyzer.get_num_segments()): + num_samples = sorting_analyzer.get_num_samples(segment_index) + edges = np.arange(0, num_samples + 1, bin_size_samples) - spike_counts, _ = np.histogram(spike_train, bins=bin_edges) - firing_rate_histograms[unit_id] = spike_counts / bin_size_s + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_counts, _ = np.histogram(spike_times, bins=edges) + firing_rates = spike_counts / bin_size_s + firing_rate_histograms[unit_id] = 
np.concatenate((firing_rate_histograms[unit_id], firing_rates)) # finally we compute the percentiles firing_ranges = {} @@ -694,7 +635,6 @@ def compute_amplitude_cv_metrics( percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", - periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. @@ -718,8 +658,6 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -738,15 +676,14 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - if unit_ids is None: - unit_ids = sorting.unit_ids sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - + total_duration = sorting_analyzer.get_total_duration() spikes = sorting.to_spike_vector() - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") - amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) + if unit_ids is None: + unit_ids = sorting.unit_ids + + amps = sorting_analyzer.get_extension(amplitude_extension).get_data() # precompute segment slice segment_slices = [] @@ -758,7 +695,6 @@ def compute_amplitude_cv_metrics( all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: - total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency @@ -816,7 +752,6 @@ def 
compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, - periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -835,9 +770,6 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -873,7 +805,7 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -905,7 +837,7 @@ class AmplitudeCutoff(BaseMetric): depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None): """ Compute median of the amplitude distributions (in absolute value). @@ -915,9 +847,6 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -936,7 +865,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -953,9 +882,7 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs( - sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None -): +def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -979,9 +906,6 @@ def compute_noise_cutoffs( Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -1010,7 +934,7 @@ def compute_noise_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -1048,7 +972,6 @@ def compute_drift_metrics( min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, - periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1083,9 +1006,6 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). 
@@ -1108,18 +1028,35 @@ compute_drift_metrics check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data(periods=periods) - spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) + spike_locations = spike_locations_ext.get_data() + # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") + spikes = sorting.to_spike_vector() + spike_locations_by_unit = {} + for unit_id in unit_ids: + unit_index = sorting.id_to_index(unit_id) + # TODO @alessio this is very slow this should be done with spike_vector_to_indices() in code + spike_mask = spikes["unit_index"] == unit_index + spike_locations_by_unit[unit_id] = spike_locations[spike_mask] - segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] + interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) + total_duration = sorting_analyzer.get_total_duration() + if total_duration < min_num_bins * interval_s: + warnings.warn( + "The recording is too short given the specified 'interval_s' and " + "'min_num_bins'. 
Drift metrics will be set to NaN" + ) + empty_dict = {unit_id: np.nan for unit_id in unit_ids} + if return_positions: + return res(empty_dict, empty_dict, empty_dict), np.nan + else: + return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1127,58 +1064,45 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = {} - for unit_id in unit_ids: - reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) + reference_positions = np.zeros(len(unit_ids)) + for i, unit_id in enumerate(unit_ids): + unit_ind = sorting.id_to_index(unit_id) + reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments - spike_vector = sorting.to_spike_vector() - spike_sample_indices = spike_vector["sample_index"] - # we need to add the cumulative sum of segment samples to have global sample indices - cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) + median_position_segments = None for segment_index in range(sorting_analyzer.get_num_segments()): - seg_mask = spike_vector["segment_index"] == segment_index - spike_sample_indices[seg_mask] += cumulative_segment_samples[segment_index] - - bin_edges_for_units = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=interval_s, - ) - - median_positions_per_unit = {} - for unit in unit_ids: - bins = bin_edges_for_units[unit] - num_bins = len(bins) - 1 - if num_bins < min_num_bins: - warnings.warn( - f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " - f"'min_num_bins'. 
Drift metrics will be set to NaN" - ) - drift_ptps[unit] = np.nan - drift_stds[unit] = np.nan - drift_mads[unit] = np.nan - continue - - # bin_edges are global across segments, so we have to use spike_sample_indices, - # since we offseted them to be global - bin_spike_indices = np.searchsorted(spike_sample_indices, bins) - median_positions = np.nan * np.zeros(num_bins) - for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): - spikes_in_bin = spike_vector[i0:i1] - spike_locations_in_bin = spike_locations[i0:i1][direction] - - unit_index = sorting_analyzer.sorting.id_to_index(unit) - mask = spikes_in_bin["unit_index"] == unit_index - if np.sum(mask) >= min_spikes_per_interval: - median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) - else: - median_positions[bin_index] = np.nan - median_positions_per_unit[unit] = median_positions + seg_length = sorting_analyzer.get_num_samples(segment_index) + num_bin_edges = seg_length // interval_samples + 1 + bins = np.arange(num_bin_edges) * interval_samples + spike_vector = sorting.to_spike_vector() + + # retrieve spikes in segment + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) + spikes_in_segment = spike_vector[i0:i1] + spike_locations_in_segment = spike_locations[i0:i1] + + # compute median positions (if less than min_spikes_per_interval, median position is NaN) + median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) + for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) + spikes_in_bin = spikes_in_segment[i0:i1] + spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] + + for i, unit_id in enumerate(unit_ids): + unit_ind = sorting.id_to_index(unit_id) + mask = spikes_in_bin["unit_index"] == unit_ind + if np.sum(mask) >= min_spikes_per_interval: + median_positions[i, bin_index] = 
np.median(spike_locations_in_bin[mask]) + if median_position_segments is None: + median_position_segments = median_positions + else: + median_position_segments = np.hstack((median_position_segments, median_positions)) - # now compute deviations and drifts for this unit - position_diff = median_positions - reference_positions[unit_id] + # finally, compute deviations and drifts + position_diffs = median_position_segments - reference_positions[:, None] + for i, unit_id in enumerate(unit_ids): + position_diff = position_diffs[i] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): @@ -1196,9 +1120,8 @@ def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift - if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit + outs = res(drift_ptps, drift_stds, drift_mads), median_positions else: outs = res(drift_ptps, drift_stds, drift_mads) return outs @@ -1228,7 +1151,6 @@ def compute_sd_ratio( censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, - periods=None, **kwargs, ): """ @@ -1251,9 +1173,6 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. 
@@ -1270,7 +1189,6 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1283,7 +1201,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods) + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() if not HAVE_NUMBA: warnings.warn( @@ -1412,7 +1330,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1420,6 +1338,8 @@ def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes ---------- spike_train : np.ndarray Spike times for this unit, in samples. + total_length : int + Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). 
num_bin_edges : int, optional @@ -1789,6 +1709,29 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts +def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): + # used by compute_amplitude_cutoffs and compute_amplitude_medians + + if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: + return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) + + elif sorting_analyzer.has_extension("waveforms"): + amplitudes_by_units = {} + waveforms_ext = sorting_analyzer.get_extension("waveforms") + before = waveforms_ext.nbefore + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) + for unit_id in unit_ids: + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) + chan_id = extremum_channels_ids[unit_id] + if sorting_analyzer.is_sparse(): + chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] + else: + chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] + amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] + + return amplitudes_by_units + + if HAVE_NUMBA: import numba diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 57516d6bc3..c0dd6c6033 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,8 +12,11 @@ synthesize_random_firings, ) -from spikeinterface.metrics.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +# from spikeinterface.metrics.quality_metric_list import ( +# _misc_metric_name_to_func, +# ) from spikeinterface.metrics.quality import ( get_quality_metric_list, diff --git a/src/spikeinterface/metrics/quality/utils.py 
b/src/spikeinterface/metrics/quality/utils.py new file mode 100644 index 0000000000..844a7da7f5 --- /dev/null +++ b/src/spikeinterface/metrics/quality/utils.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import numpy as np + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. + all_labels : numpy.array + Labels for each point. + """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index 0ddb5fabe7..ba66d0671c 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -2,7 +2,7 @@ from 
spikeinterface.core.analyzer_extension_core import BaseMetric -def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): +def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): """ Compute the number of spike across segments. @@ -12,8 +12,6 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the number of spikes. If None, all units are used. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -22,7 +20,6 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): """ sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids num_segs = sorting.get_num_segments() @@ -46,7 +43,7 @@ class NumSpikes(BaseMetric): metric_columns = {"num_spikes": int} -def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): +def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -56,8 +53,6 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the firing rate. If None, all units are used. 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -66,7 +61,6 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): """ sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids total_duration = sorting_analyzer.get_total_duration() From cd854567b45d0eddc16275431ef5c6044188c0ec Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:43:53 +0100 Subject: [PATCH 14/40] remove utils --- src/spikeinterface/metrics/utils.py | 142 ---------------------------- 1 file changed, 142 deletions(-) delete mode 100644 src/spikeinterface/metrics/utils.py diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py deleted file mode 100644 index 16058e521f..0000000000 --- a/src/spikeinterface/metrics/utils.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): - """ - Compute bin edges for units, optionally taking into account periods. - - Parameters - ---------- - sorting : Sorting - Sorting object containing unit information. - segment_samples : list or array-like - Number of samples in each segment. 
- bin_duration_s : float, default: 1 - Duration of each bin in seconds - periods : array of unit_period_dtype, default: None - Periods to consider for each unit - """ - bin_edges_for_units = {} - num_segments = len(segment_samples) - bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency) - - if periods is not None: - for unit_id in sorting.unit_ids: - unit_index = sorting.id_to_index(unit_id) - periods_unit = periods[periods["unit_index"] == unit_index] - bin_edges = [] - for seg_index in range(num_segments): - seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] - if len(seg_periods) == 0: - continue - seg_start = np.sum(segment_samples[:seg_index]) - for period in seg_periods: - start_sample = seg_start + period["start_sample_index"] - end_sample = seg_start + period["end_sample_index"] - bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) - bin_edges_for_units[unit_id] = np.array(bin_edges) - else: - total_length = np.sum(segment_samples) - for unit_id in sorting.unit_ids: - bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples - return bin_edges_for_units - - -def compute_total_samples_per_unit(sorting_analyzer, periods=None): - """ - Get total number of samples for each unit, optionally taking into account periods. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The sorting analyzer object. - periods : array of unit_period_dtype, default: None - Periods to consider for each unit. - - Returns - ------- - dict - Total number of samples for each unit. 
- """ - if periods is not None: - total_samples = {} - sorting = sorting_analyzer.sorting - for unit_id in sorting.unit_ids: - unit_index = sorting.id_to_index(unit_id) - periods_unit = periods[periods["unit_index"] == unit_index] - num_samples_in_period = 0 - for period in periods_unit: - num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] - total_samples[unit_id] = num_samples_in_period - else: - total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} - return total_samples - - -def compute_total_durations_per_unit(sorting_analyzer, periods=None): - """ - Compute total duration for each unit, optionally taking into account periods. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The sorting analyzer object. - periods : array of unit_period_dtype, default: None - Periods to consider for each unit. - - Returns - ------- - dict - Total duration for each unit. - """ - total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) - total_durations = { - unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() - } - return total_durations - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. - - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. - total_points : array-like - Number of points in each unit distribution. - - Returns - ------- - all_pcs : numpy.ndarray - PC scores for each point. - all_labels : numpy.array - Labels for each point. 
- """ - from scipy.stats import multivariate_normal - - np.random.seed(0) - - if len(np.array(center_locations).shape) == 1: - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations, total_points) - ] - all_pcs = np.concatenate(distributions, axis=0) - - else: - all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) - for channel in range(center_locations.shape[0]): - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations[channel], total_points) - ] - all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) - - all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) - - return all_pcs, all_labels From 7a42fe32354b1ffd024a32ec327d2353de2196a0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:46:14 +0100 Subject: [PATCH 15/40] rebase on #4316 --- src/spikeinterface/metrics/quality/utils.py | 47 ------- src/spikeinterface/metrics/utils.py | 142 ++++++++++++++++++++ 2 files changed, 142 insertions(+), 47 deletions(-) delete mode 100644 src/spikeinterface/metrics/quality/utils.py create mode 100644 src/spikeinterface/metrics/utils.py diff --git a/src/spikeinterface/metrics/quality/utils.py b/src/spikeinterface/metrics/quality/utils.py deleted file mode 100644 index 844a7da7f5..0000000000 --- a/src/spikeinterface/metrics/quality/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. 
- - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. - total_points : array-like - Number of points in each unit distribution. - - Returns - ------- - all_pcs : numpy.ndarray - PC scores for each point. - all_labels : numpy.array - Labels for each point. - """ - from scipy.stats import multivariate_normal - - np.random.seed(0) - - if len(np.array(center_locations).shape) == 1: - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations, total_points) - ] - all_pcs = np.concatenate(distributions, axis=0) - - else: - all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) - for channel in range(center_locations.shape[0]): - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations[channel], total_points) - ] - all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) - - all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) - - return all_pcs, all_labels diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py new file mode 100644 index 0000000000..16058e521f --- /dev/null +++ b/src/spikeinterface/metrics/utils.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import numpy as np + + +def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): + """ + Compute bin edges for units, optionally taking into account periods. + + Parameters + ---------- + sorting : Sorting + Sorting object containing unit information. + segment_samples : list or array-like + Number of samples in each segment. 
+ bin_duration_s : float, default: 1 + Duration of each bin in seconds + periods : array of unit_period_dtype, default: None + Periods to consider for each unit + """ + bin_edges_for_units = {} + num_segments = len(segment_samples) + bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency) + + if periods is not None: + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + bin_edges = [] + for seg_index in range(num_segments): + seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] + if len(seg_periods) == 0: + continue + seg_start = np.sum(segment_samples[:seg_index]) + for period in seg_periods: + start_sample = seg_start + period["start_sample_index"] + end_sample = seg_start + period["end_sample_index"] + bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) + bin_edges_for_units[unit_id] = np.array(bin_edges) + else: + total_length = np.sum(segment_samples) + for unit_id in sorting.unit_ids: + bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples + return bin_edges_for_units + + +def compute_total_samples_per_unit(sorting_analyzer, periods=None): + """ + Get total number of samples for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total number of samples for each unit. 
+ """ + if periods is not None: + total_samples = {} + sorting = sorting_analyzer.sorting + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + num_samples_in_period = 0 + for period in periods_unit: + num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] + total_samples[unit_id] = num_samples_in_period + else: + total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} + return total_samples + + +def compute_total_durations_per_unit(sorting_analyzer, periods=None): + """ + Compute total duration for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total duration for each unit. + """ + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) + total_durations = { + unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() + } + return total_durations + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. + all_labels : numpy.array + Labels for each point. 
+ """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels From cbc0986cdfe485c7177e4e84673a3597d0c6dafd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 09:34:49 +0100 Subject: [PATCH 16/40] Fix import --- src/spikeinterface/core/sorting_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 75e25115ae..f5cf82c76f 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -281,8 +281,8 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: A new sorting object with only samples between start_sample_index and end_sample_index for the given segment_index. 
""" + from spikeinterface.core.base import unit_period_dtype from spikeinterface.core.numpyextractors import NumpySorting - from spikeinterface.core.node_pipeline import unit_period_dtype if periods is not None: if not isinstance(periods, np.ndarray): @@ -295,6 +295,7 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: keep_mask = select_sorting_periods_mask(sorting, periods) sliced_spike_vector = spike_vector[keep_mask] + # important: we keep the original unit ids so the unit_index field in spike vector is still valid sorting = NumpySorting( sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids ) From 046430e13bd48c099e5ed3e756b32df20ba5fd43 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 09:42:45 +0100 Subject: [PATCH 17/40] fix import --- .../metrics/quality/tests/test_metrics_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index fb764dac78..b4f956e6a7 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,7 +12,7 @@ synthesize_random_firings, ) -from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions # from spikeinterface.metrics.quality_metric_list import ( # _misc_metric_name_to_func, From bb8625358e22bdeeecf07872579dbe56a36a25f3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 09:48:47 +0100 Subject: [PATCH 18/40] Add misc_metric changes --- .../metrics/quality/misc_metrics.py | 368 +++++++++++------- 1 file changed, 231 insertions(+), 137 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 
c6b07da52e..2d90493756 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -25,8 +25,12 @@ get_template_extremum_amplitude, get_dense_templates_array, ) - -from ..spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.utils import ( + compute_bin_edges_per_unit, + compute_total_durations_per_unit, + compute_total_samples_per_unit, +) numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -35,7 +39,9 @@ HAVE_NUMBA = False -def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): +def compute_presence_ratios( + sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None +): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -51,6 +57,9 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -63,16 +72,21 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 To do so, spike trains across segments are concatenated to mimic a continuous segment. 
""" sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] - total_length = sorting_analyzer.get_total_samples() - total_duration = sorting_analyzer.get_total_duration() + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_samples = np.sum(segment_samples) bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - num_bin_edges = total_length // bin_duration_samples + 1 - bin_edges = np.arange(num_bin_edges) * bin_duration_samples + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_duration_s, + ) mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: @@ -83,7 +97,7 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_length < bin_duration_samples: + if total_samples < bin_duration_samples: warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording duration. " f"Presence ratios are set to NaN." 
) @@ -91,9 +105,15 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 else: for unit_id in unit_ids: spike_train = [] + bin_edges = bin_edges_per_unit[unit_id] + if len(bin_edges) < 2: + presence_ratios[unit_id] = 0.0 + continue + total_duration = total_durations[unit_id] + for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(seg_lengths[:segment_index]) + st = st + np.sum(segment_samples[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -102,7 +122,6 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 presence_ratios[unit_id] = presence_ratio( spike_train, - total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -182,7 +201,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): """ Calculate Inter-Spike Interval (ISI) violations. @@ -204,6 +223,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -235,11 +257,12 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - total_duration_s = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -260,7 +283,8 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 if not any([len(train) > 0 for train in spike_train_list]): continue - ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) + total_duration = total_durations[unit_id] + ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -280,7 +304,7 @@ class ISIViolation(BaseMetric): def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None ): """ Calculate the number of refractory period violations. @@ -300,6 +324,9 @@ def compute_refrac_period_violations( censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -322,8 +349,6 @@ def compute_refrac_period_violations( ---------- Based on metrics described in [Llobet]_ """ - from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes - res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) if not HAVE_NUMBA: @@ -332,16 +357,18 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting - fs = sorting_analyzer.sampling_frequency - num_units = len(sorting_analyzer.unit_ids) - num_segments = sorting_analyzer.get_num_segments() + sorting = sorting.select_periods(periods=periods) + + fs = sorting.sampling_frequency + num_units = len(sorting.unit_ids) + num_segments = sorting.get_num_segments() spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids + unit_ids = sorting.unit_ids - num_spikes = compute_num_spikes(sorting_analyzer) + num_spikes = sorting.count_num_spikes_per_unit() t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -352,7 +379,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = sorting_analyzer.get_total_samples() + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) nb_violations = {} rp_contamination = {} @@ -360,14 +387,15 @@ def compute_refrac_period_violations( for unit_index, unit_id in enumerate(sorting.unit_ids): if unit_id not in unit_ids: continue - - nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] - N = num_spikes[unit_id] - if N == 0: - rp_contamination[unit_id] = np.nan - else: - D = 1 - n_v * (T - 2 * N * t_c) / (N**2 * (t_r - t_c)) - rp_contamination[unit_id] = 1 - math.sqrt(D) if D >= 0 else 1.0 + total_samples_unit = total_samples[unit_id] + nb_violations[unit_id] = nb_rp_violations[unit_index] + rp_contamination[unit_id] = 
_compute_rp_contamination_one_unit( + nb_rp_violations[unit_index], + num_spikes[unit_id], + total_samples_unit, + t_c, + t_r, + ) return res(rp_contamination, nb_violations) @@ -392,6 +420,7 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, + periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -417,6 +446,9 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -429,8 +461,10 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -455,6 +489,7 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue + duration = total_durations[unit_id] contamination[unit_id] = slidingRP_violations( spike_train_list, fs, @@ -486,7 +521,7 @@ class SlidingRPViolation(BaseMetric): } -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): """ Compute synchrony metrics. 
Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -504,6 +539,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. References ---------- @@ -520,6 +558,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -556,7 +595,7 @@ class Synchrony(BaseMetric): } -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -571,6 +610,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -584,6 +626,11 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + segment_samples = [ + sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) + ] + if unit_ids is None: unit_ids = sorting.unit_ids @@ -598,15 +645,25 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - for segment_index in range(sorting_analyzer.get_num_segments()): - num_samples = sorting_analyzer.get_num_samples(segment_index) - edges = np.arange(0, num_samples + 1, bin_size_samples) + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_size_s, + ) + for unit_id in unit_ids: + bin_edges = bin_edges_per_unit[unit_id] - for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - spike_counts, _ = np.histogram(spike_times, bins=edges) - firing_rates = spike_counts / bin_size_s - firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + # we can concatenate spike trains across segments adding the cumulative number of samples + # as offset, since bin edges are already cumulative + for segment_index in range(sorting_analyzer.get_num_segments()): + st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + st = st + np.sum(segment_samples[:segment_index]) + spike_train.append(st) + spike_train = np.concatenate(spike_train) + + spike_counts, _ = np.histogram(spike_train, bins=bin_edges) + 
firing_rate_histograms[unit_id] = spike_counts / bin_size_s # finally we compute the percentiles firing_ranges = {} @@ -635,6 +692,7 @@ def compute_amplitude_cv_metrics( percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", + periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. @@ -658,6 +716,8 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -676,14 +736,15 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = sorting_analyzer.sorting - total_duration = sorting_analyzer.get_total_duration() - spikes = sorting.to_spike_vector() - num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") if unit_ids is None: unit_ids = sorting.unit_ids + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) - amps = sorting_analyzer.get_extension(amplitude_extension).get_data() + spikes = sorting.to_spike_vector() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") + amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice segment_slices = [] @@ -695,6 +756,7 @@ def compute_amplitude_cv_metrics( all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin 
/ firing_rate) * sorting_analyzer.sampling_frequency @@ -752,6 +814,7 @@ def compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, + periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -770,6 +833,9 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -805,13 +871,12 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] if invert_amplitudes: amplitudes = -amplitudes - all_fraction_missing[unit_id] = amplitude_cutoff( amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio ) @@ -837,7 +902,7 @@ class AmplitudeCutoff(BaseMetric): depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): """ Compute median of the amplitude distributions (in absolute value). @@ -847,6 +912,9 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -865,7 +933,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -882,7 +950,9 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): +def compute_noise_cutoffs( + sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None +): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -906,6 +976,9 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -934,7 +1007,7 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -972,6 +1045,7 @@ def compute_drift_metrics( min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, + periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1006,6 +1080,9 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). 
@@ -1028,35 +1105,18 @@ def compute_drift_metrics( check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data() - # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") - spikes = sorting.to_spike_vector() - spike_locations_by_unit = {} - for unit_id in unit_ids: - unit_index = sorting.id_to_index(unit_id) - # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code - spike_mask = spikes["unit_index"] == unit_index - spike_locations_by_unit[unit_id] = spike_locations[spike_mask] + spike_locations = spike_locations_ext.get_data(periods=periods) + spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) - interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - total_duration = sorting_analyzer.get_total_duration() - if total_duration < min_num_bins * interval_s: - warnings.warn( - "The recording is too short given the specified 'interval_s' and " - "'min_num_bins'. 
Drift metrics will be set to NaN" - ) - empty_dict = {unit_id: np.nan for unit_id in unit_ids} - if return_positions: - return res(empty_dict, empty_dict, empty_dict), np.nan - else: - return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1064,45 +1124,58 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) + reference_positions = {} + for unit_id in unit_ids: + reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments - median_position_segments = None + spike_vector = sorting.to_spike_vector() + spike_sample_indices = spike_vector["sample_index"] + # we need to add the cumulative sum of segment samples to have global sample indices + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for segment_index in range(sorting_analyzer.get_num_segments()): - seg_length = sorting_analyzer.get_num_samples(segment_index) - num_bin_edges = seg_length // interval_samples + 1 - bins = np.arange(num_bin_edges) * interval_samples - spike_vector = sorting.to_spike_vector() - - # retrieve spikes in segment - i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) - spikes_in_segment = spike_vector[i0:i1] - spike_locations_in_segment = spike_locations[i0:i1] - - # compute median positions (if less than min_spikes_per_interval, median position is 0) - median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) - for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) - spikes_in_bin = spikes_in_segment[i0:i1] - spike_locations_in_bin = 
spike_locations_in_segment[i0:i1][direction] - - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - mask = spikes_in_bin["unit_index"] == unit_ind - if np.sum(mask) >= min_spikes_per_interval: - median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask]) - if median_position_segments is None: - median_position_segments = median_positions - else: - median_position_segments = np.hstack((median_position_segments, median_positions)) + seg_mask = spike_vector["segment_index"] == segment_index + spike_sample_indices[seg_mask] += cumulative_segment_samples[segment_index] + + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=interval_s, + ) - # finally, compute deviations and drifts - position_diffs = median_position_segments - reference_positions[:, None] - for i, unit_id in enumerate(unit_ids): - position_diff = position_diffs[i] + median_positions_per_unit = {} + for unit in unit_ids: + bins = bin_edges_for_units[unit] + num_bins = len(bins) - 1 + if num_bins < min_num_bins: + warnings.warn( + f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " + f"'min_num_bins'. 
Drift metrics will be set to NaN" + ) + drift_ptps[unit] = np.nan + drift_stds[unit] = np.nan + drift_mads[unit] = np.nan + continue + + # bin_edges are global across segments, so we have to use spike_sample_indices, + # since we offseted them to be global + bin_spike_indices = np.searchsorted(spike_sample_indices, bins) + median_positions = np.nan * np.zeros(num_bins) + for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): + spikes_in_bin = spike_vector[i0:i1] + spike_locations_in_bin = spike_locations[i0:i1][direction] + + unit_index = sorting_analyzer.sorting.id_to_index(unit) + mask = spikes_in_bin["unit_index"] == unit_index + if np.sum(mask) >= min_spikes_per_interval: + median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) + else: + median_positions[bin_index] = np.nan + median_positions_per_unit[unit] = median_positions + + # now compute deviations and drifts for this unit + position_diff = median_positions - reference_positions[unit_id] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): @@ -1120,8 +1193,9 @@ def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions + outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit else: outs = res(drift_ptps, drift_stds, drift_mads) return outs @@ -1151,6 +1225,7 @@ def compute_sd_ratio( censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, + periods=None, **kwargs, ): """ @@ -1173,6 +1248,9 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact 
is (and remove it). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. @@ -1189,6 +1267,7 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1201,7 +1280,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods) if not HAVE_NUMBA: warnings.warn( @@ -1330,7 +1409,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1338,8 +1417,6 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None ---------- spike_train : np.ndarray Spike times for this unit, in samples. - total_length : int - Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). num_bin_edges : int, optional @@ -1569,6 +1646,46 @@ def slidingRP_violations( return min_cont_with_90_confidence +def _compute_rp_contamination_one_unit( + n_v, + n_spikes, + total_samples, + t_c, + t_r, +): + """ + Compute the refractory period contamination for one unit. + + Parameters + ---------- + n_v : int + Number of refractory period violations. 
+ n_spikes : int + Number of spikes for the unit. + total_samples : int + Total number of samples in the recording. + t_c : int + Censored period in samples. + t_r : int + Refractory period in samples. + + Returns + ------- + rp_contamination : float + The refractory period contamination for the unit. + """ + if n_spikes <= 1: + return np.nan + + denom = 1 - n_v * (total_samples - 2 * n_spikes * t_c) / (n_spikes**2 * (t_r - t_c)) + if denom < 0: + return 1.0 + + rp_contamination = 1 - math.sqrt(denom) + + return rp_contamination + + def _compute_violations(obs_viol, firing_rate, spike_count, ref_period_dur, contamination_prop): contamination_rate = firing_rate * contamination_prop expected_viol = contamination_rate * ref_period_dur * 2 * spike_count @@ -1709,29 +1826,6 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): - # used by compute_amplitude_cutoffs and compute_amplitude_medians - - if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: - return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) - - elif sorting_analyzer.has_extension("waveforms"): - amplitudes_by_units = {} - waveforms_ext = sorting_analyzer.get_extension("waveforms") - before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) - for unit_id in unit_ids: - waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - chan_id = extremum_channels_ids[unit_id] - if sorting_analyzer.is_sparse(): - chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] - amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - - return amplitudes_by_units - - if HAVE_NUMBA: import numba From 50f33f0c8a1e94f7e5fcf7b3aa14058e4a836d4e Mon 
Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 15:44:53 +0100 Subject: [PATCH 19/40] fix tests --- .../metrics/quality/misc_metrics.py | 20 ++++++++++--------- .../metrics/quality/tests/conftest.py | 2 +- .../quality/tests/test_metrics_functions.py | 2 +- src/spikeinterface/metrics/utils.py | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 2d90493756..b0791d00d7 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -111,6 +111,7 @@ def compute_presence_ratios( continue total_duration = total_durations[unit_id] + spike_train = [] for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) st = st + np.sum(segment_samples[:segment_index]) @@ -656,6 +657,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # we can concatenate spike trains across segments adding the cumulative number of samples # as offset, since bin edges are already cumulative + spike_train = [] for segment_index in range(sorting_analyzer.get_num_segments()): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) st = st + np.sum(segment_samples[:segment_index]) @@ -737,7 +739,7 @@ def compute_amplitude_cv_metrics( "amplitude_scalings", ), "Invalid amplitude_extension. 
It can be either 'spike_amplitudes' or 'amplitude_scalings'" if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_analyzer.unit_ids sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) @@ -1145,17 +1147,17 @@ def compute_drift_metrics( ) median_positions_per_unit = {} - for unit in unit_ids: - bins = bin_edges_for_units[unit] + for unit_id in unit_ids: + bins = bin_edges_for_units[unit_id] num_bins = len(bins) - 1 if num_bins < min_num_bins: warnings.warn( - f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " + f"Unit {unit_id} has only {num_bins} bins given the specified 'interval_s' and " f"'min_num_bins'. Drift metrics will be set to NaN" ) - drift_ptps[unit] = np.nan - drift_stds[unit] = np.nan - drift_mads[unit] = np.nan + drift_ptps[unit_id] = np.nan + drift_stds[unit_id] = np.nan + drift_mads[unit_id] = np.nan continue # bin_edges are global across segments, so we have to use spike_sample_indices, @@ -1166,13 +1168,13 @@ def compute_drift_metrics( spikes_in_bin = spike_vector[i0:i1] spike_locations_in_bin = spike_locations[i0:i1][direction] - unit_index = sorting_analyzer.sorting.id_to_index(unit) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) mask = spikes_in_bin["unit_index"] == unit_index if np.sum(mask) >= min_spikes_per_interval: median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) else: median_positions[bin_index] = np.nan - median_positions_per_unit[unit] = median_positions + median_positions_per_unit[unit_id] = median_positions # now compute deviations and drifts for this unit position_diff = median_positions - reference_positions[unit_id] diff --git a/src/spikeinterface/metrics/quality/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py index c2a6c6fe82..5313e763c1 100644 --- a/src/spikeinterface/metrics/quality/tests/conftest.py +++ b/src/spikeinterface/metrics/quality/tests/conftest.py @@ -10,7 +10,7 @@ def 
make_small_analyzer(): recording, sorting = generate_ground_truth_recording( - durations=[2.0], + durations=[10.0], num_units=10, seed=1205, ) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index b4f956e6a7..8b6e67d119 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -223,7 +223,7 @@ def test_unit_structure_in_output(small_sorting_analyzer): "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, "firing_range": {"bin_size_s": 1}, "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, } diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 16058e521f..91538498aa 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -40,7 +40,7 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per else: total_length = np.sum(segment_samples) for unit_id in sorting.unit_ids: - bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples + bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) return bin_edges_for_units From 80bc50fa59071ac9640b73818f8708e0fae41757 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 15:05:37 +0100 Subject: [PATCH 20/40] Change base_period_dtype order and fix select_sorting_periods array input --- src/spikeinterface/core/base.py | 2 +- src/spikeinterface/core/sorting_tools.py | 26 ++++++++++++++++--- .../core/tests/test_basesorting.py | 14 ++++++++-- 3 
files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3505853835..4520c19819 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -42,9 +42,9 @@ minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] base_period_dtype = [ + ("segment_index", "int64"), ("start_sample_index", "int64"), ("end_sample_index", "int64"), - ("segment_index", "int64"), ] unit_period_dtype = base_period_dtype + [ diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index f5cf82c76f..d05cc33869 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -271,9 +271,11 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: Parameters ---------- - periods : numpy.array of unit_period_dtype + periods : numpy.ndarray Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to restrict the sorting. + on which to restrict the sorting. Periods can be either a numpy array of unit_period_dtype + or an array with (num_periods, 4) shape. In the latter case, the fields are assumed to be + in the order: segment_index, start_sample_index, end_sample_index, unit_index. Returns ------- @@ -286,7 +288,25 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: if periods is not None: if not isinstance(periods, np.ndarray): - periods = np.array([periods], dtype=unit_period_dtype) + raise ValueError("periods must be a numpy array") + if not periods.dtype == unit_period_dtype: + if periods.ndim != 2 or periods.shape[1] != 4: + raise ValueError( + "If periods is not of dtype unit_period_dtype, it must be a 2D array with shape (num_periods, 4)" + ) + warnings.warn( + "periods is not of dtype unit_period_dtype. 
Assuming fields are in order: " + "(segment_index, start_sample_index, end_sample_index, unit_index).", + UserWarning, + ) + # convert to structured array + periods_converted = np.empty(periods.shape[0], dtype=unit_period_dtype) + periods_converted["segment_index"] = periods[:, 0] + periods_converted["start_sample_index"] = periods[:, 1] + periods_converted["end_sample_index"] = periods[:, 2] + periods_converted["unit_index"] = periods[:, 3] + periods = periods_converted + required = set(np.dtype(unit_period_dtype).names) if not required.issubset(periods.dtype.names): raise ValueError(f"Period must have the following fields: {required}") diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 963320c2a1..ed1931e87a 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -225,7 +225,7 @@ def test_time_slice(): def test_select_periods(): sampling_frequency = 10_000.0 - duration = 1_000 + duration = 100 num_samples = int(sampling_frequency * duration) num_units = 1000 sorting = generate_sorting( @@ -235,7 +235,7 @@ def test_select_periods(): rng = np.random.default_rng() # number of random periods - n_periods = 10_000 + n_periods = 1_000 # generate random periods segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods) start_samples = rng.integers(0, num_samples, n_periods) @@ -280,6 +280,16 @@ def test_select_periods(): spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) assert len(spiketrain_in_periods) == len(spiketrain_sliced) + # now test with input as numpy array with shape (n_periods, 4) + periods_array = np.zeros((len(periods), 4), dtype="int64") + periods_array[:, 0] = periods["segment_index"] + periods_array[:, 1] = periods["start_sample_index"] + periods_array[:, 2] = periods["end_sample_index"] + periods_array[:, 3] = periods["unit_index"] + + 
sliced_sorting_array = sorting.select_periods(periods=periods_array) + np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector()) + if __name__ == "__main__": test_BaseSorting() From 96e6a5317e1e8a76b77d37e9a56f5e399ad5e1b1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 17:02:02 +0100 Subject: [PATCH 21/40] fix tests --- src/spikeinterface/metrics/quality/misc_metrics.py | 9 ++++----- src/spikeinterface/metrics/spiketrain/metrics.py | 7 ++++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 9942c2f707..835465a4c1 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -382,10 +382,10 @@ def compute_refrac_period_violations( for segment_index in range(sorting_analyzer.get_num_segments()): spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - nb_rp_violations[unit_id] += _compute_rp_violations_numba(spike_times, t_c, t_r) + nb_violations[unit_id] += _compute_rp_violations_numba(spike_times, t_c, t_r) rp_contamination[unit_id] = _compute_rp_contamination_one_unit( - nb_rp_violations[unit_id], + nb_violations[unit_id], num_spikes[unit_id], total_samples_unit, t_c, @@ -1122,9 +1122,8 @@ def compute_drift_metrics( # we need to add the cumulative sum of segment samples to have global sample indices cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for segment_index in range(sorting_analyzer.get_num_segments()): - spike_sample_indices[sorting._get_spike_vector_segment_slices()[segment_index]] += cumulative_segment_samples[ - segment_index - ] + segment_slice = sorting._get_spike_vector_segment_slices()[segment_index] + spike_sample_indices[segment_slice[0] : segment_slice[1]] += cumulative_segment_samples[segment_index] bin_edges_for_units = 
compute_bin_edges_per_unit( sorting, diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index 600ae2e406..669733f47a 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -26,7 +26,12 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None): sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids - return sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + # re-order dict to match unit_ids order + count_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + num_spikes = {} + for unit_id in unit_ids: + num_spikes[unit_id] = count_spikes[unit_id] + return num_spikes class NumSpikes(BaseMetric): From 319891137e9d395862aa26d7b78e54774626922c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 17:38:59 +0100 Subject: [PATCH 22/40] Fix generation of bins --- src/spikeinterface/metrics/quality/misc_metrics.py | 5 +++-- src/spikeinterface/metrics/utils.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 835465a4c1..85bd507c2a 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -648,6 +648,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent periods=periods, bin_duration_s=bin_size_s, ) + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: bin_edges = bin_edges_per_unit[unit_id] @@ -656,9 +657,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent spike_trains = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - spike_times = spike_times + 
np.sum(segment_samples[:segment_index]) + spike_times = spike_times + cumulative_segment_samples[segment_index] spike_trains.append(spike_times) - spike_train = np.concatenate(spike_trains) + spike_train = np.concatenate(spike_trains, dtype="int64") spike_counts, _ = np.histogram(spike_train, bins=bin_edges) firing_rate_histograms[unit_id] = spike_counts / bin_size_s diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 222503a730..83ddfcf90b 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -35,12 +35,22 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per for period in seg_periods: start_sample = seg_start + period["start_sample_index"] end_sample = seg_start + period["end_sample_index"] + end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1 # align to bin bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) bin_edges_for_units[unit_id] = np.array(bin_edges) else: - total_length = np.sum(segment_samples) for unit_id in sorting.unit_ids: - bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) + bin_edges = [] + for seg_index in range(num_segments): + seg_start = np.sum(segment_samples[:seg_index]) + seg_end = seg_start + segment_samples[seg_index] + # for segments which are not the last, we don't need to correct the end + # since the first index of the next segment will be the end of the current segment + if seg_index == num_segments - 1: + seg_end = seg_end // bin_duration_samples * bin_duration_samples + 1 # align to bin + bins = np.arange(seg_start, seg_end, bin_duration_samples) + bin_edges.extend(bins) + bin_edges_for_units[unit_id] = np.array(bin_edges) return bin_edges_for_units From 7446a43187f4434a466a8ae72153d57585833cb5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 16 Jan 2026 14:48:12 +0100 Subject: [PATCH 23/40] Use cached get_spike_vector_to_indices --- 
src/spikeinterface/core/sorting_tools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index d05cc33869..bc0a1871af 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -246,9 +246,8 @@ def select_sorting_periods_mask(sorting: BaseSorting, periods): A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods. """ spike_vector = sorting.to_spike_vector() - spike_vector_list = sorting.to_spike_vector(concatenated=False) keep_mask = np.zeros(len(spike_vector), dtype=bool) - all_global_indices = spike_vector_to_indices(spike_vector_list, unit_ids=sorting.unit_ids, absolute_index=True) + all_global_indices = sorting.get_spike_vector_to_indices() for segment_index in range(sorting.get_num_segments()): global_indices_segment = all_global_indices[segment_index] # filter periods by segment From 51e906a5ef93b4f13b32f7a2e5c273f5b7073ae0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 16 Jan 2026 17:43:09 +0100 Subject: [PATCH 24/40] Fix error in merging --- src/spikeinterface/metrics/quality/misc_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 6b217d197d..8176e07628 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -349,6 +349,8 @@ def compute_refrac_period_violations( """ res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -357,8 +359,6 @@ def compute_refrac_period_violations( warnings.warn("compute_refrac_period_violations cannot run without numba.") return {unit_id: np.nan for unit_id in 
unit_ids} - sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) fs = sorting_analyzer.sampling_frequency From 220951425b4e1f608c0a97fcab8c6002404da343 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 16:06:34 +0100 Subject: [PATCH 25/40] Add supports_periods in BaseMetric/Extension --- .../core/analyzer_extension_core.py | 21 +++- .../metrics/quality/misc_metrics.py | 100 ++++++++++-------- 2 files changed, 73 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 0fe7fc81c1..bc81f063d1 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -823,10 +823,9 @@ class BaseMetric: metric_columns = {} # column names and their dtypes of the dataframe metric_descriptions = {} # descriptions of each metric column needs_recording = False # whether the metric needs recording - needs_tmp_data = ( - False # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level - ) - needs_job_kwargs = False + needs_tmp_data = False # whether the metric needs temporary data computed with MetricExtension._prepare_data + needs_job_kwargs = False # whether the metric needs job_kwargs + supports_periods = False # whether the metric function supports periods depend_on = [] # extensions the metric depends on # the metric function must have the signature: @@ -839,7 +838,7 @@ class BaseMetric: metric_function = None # to be defined in subclass @classmethod - def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs): + def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs, periods=None): """Compute the metric. 
Parameters @@ -854,6 +853,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs Temporary data to pass to the metric function job_kwargs : dict Job keyword arguments to control parallelization + periods : np.ndarray | None + Numpy array of unit periods of unit_period_dtype if supports_periods is True Returns ------- @@ -865,6 +866,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs args += (tmp_data,) if cls.needs_job_kwargs: args += (job_kwargs,) + if cls.supports_periods: + args += (periods,) results = cls.metric_function(*args, **metric_params) @@ -988,6 +991,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods: np.ndarray | None = None, **other_params, ): """ @@ -1004,6 +1008,8 @@ def _set_params( If True, existing metrics in the extension will be deleted before computing new ones. metrics_to_compute : list[str] | None List of metric names to compute. If None, all metrics in `metric_names` are computed. + periods : np.ndarray | None + Numpy array of unit_period_dtype defining periods to compute metrics over. other_params : dict Additional parameters for metric computation. 
@@ -1079,6 +1085,7 @@ def _set_params( metrics_to_compute=metrics_to_compute, delete_existing_metrics=delete_existing_metrics, metric_params=metric_params, + periods=periods, **other_params, ) return params @@ -1129,6 +1136,8 @@ def _compute_metrics( if metric_names is None: metric_names = self.params["metric_names"] + periods = self.params.get("periods", None) + column_names_dtypes = {} for metric_name in metric_names: metric = [m for m in self.metric_list if m.metric_name == metric_name][0] @@ -1153,6 +1162,7 @@ def _compute_metrics( metric_params=metric_params, tmp_data=tmp_data, job_kwargs=job_kwargs, + periods=periods, ) except Exception as e: warnings.warn(f"Error computing metric {metric_name}: {e}") @@ -1179,6 +1189,7 @@ def _run(self, **job_kwargs): metrics_to_compute = self.params["metrics_to_compute"] delete_existing_metrics = self.params["delete_existing_metrics"] + periods = self.params.get("periods", None) _, job_kwargs = split_job_kwargs(job_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 8176e07628..ec0caf8138 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -39,7 +39,7 @@ def compute_presence_ratios( - sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None + sorting_analyzer, unit_ids=None, periods=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0 ): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -50,15 +50,15 @@ def compute_presence_ratios( A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the presence ratio. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. 
If None, the entire recording duration is used. bin_duration_s : float, default: 60 The duration of each bin in seconds. If the duration is less than this value, presence_ratio is set to NaN. mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -136,6 +136,7 @@ class PresenceRatio(BaseMetric): metric_params = {"bin_duration_s": 60, "mean_fr_ratio_thresh": 0.0} metric_columns = {"presence_ratio": float} metric_descriptions = {"presence_ratio": "Fraction of time the unit is active."} + supports_periods = True def compute_snrs( @@ -199,10 +200,11 @@ class SNR(BaseMetric): metric_params = {"peak_sign": "neg", "peak_mode": "extremum"} metric_columns = {"snr": float} metric_descriptions = {"snr": "Signal to noise ratio for each unit."} + supports_periods = True depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): +def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -217,6 +219,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 The SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the ISI violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. This is the biophysical refractory period. @@ -224,9 +229,6 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -299,10 +301,11 @@ class ISIViolation(BaseMetric): "isi_violations_ratio": "Ratio of ISI violations for each unit.", "isi_violations_count": "Count of ISI violations for each unit.", } + supports_periods = True def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None + sorting_analyzer, unit_ids=None, periods=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 ): """ Calculate the number of refractory period violations. @@ -317,14 +320,14 @@ def compute_refrac_period_violations( The SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the refractory period violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. refractory_period_ms : float, default: 1.0 The period (in ms) where no 2 good spikes can occur. censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -397,18 +400,19 @@ class RPViolation(BaseMetric): "rp_contamination": "Refractory period contamination described in Llobet & Wyngaard 2022.", "rp_violations": "Number of refractory period violations.", } + supports_periods = True def compute_sliding_rp_violations( sorting_analyzer, unit_ids=None, + periods=None, min_spikes=0, bin_size_ms=0.25, window_size_s=1, exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, - periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -421,6 +425,9 @@ def compute_sliding_rp_violations( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the sliding RP violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. min_spikes : int, default: 0 Contamination is set to np.nan if the unit has less than this many spikes across all segments. @@ -434,9 +441,6 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -508,9 +512,10 @@ class SlidingRPViolation(BaseMetric): metric_descriptions = { "sliding_rp_violation": "Minimum contamination at 90% confidence using sliding refractory period method." 
} + supports_periods = True -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -521,6 +526,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N A SortingAnalyzer object. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. synchrony_sizes: None, default: None Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. @@ -528,9 +536,6 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
References ---------- @@ -583,9 +588,10 @@ class Synchrony(BaseMetric): "sync_spike_4": "Fraction of spikes that are synchronous with at least three other spikes.", "sync_spike_8": "Fraction of spikes that are synchronous with at least seven other spikes.", } + supports_periods = True -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_size_s=5, percentiles=(5, 95)): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -596,13 +602,13 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the firing range. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. bin_size_s : float, default: 5 The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -675,16 +681,17 @@ class FiringRange(BaseMetric): metric_descriptions = { "firing_range": "Range between the percentiles (default: 5th and 95th) of the firing rates distribution." } + supports_periods = True def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, + periods=None, average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", - periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. 
@@ -697,6 +704,9 @@ def compute_amplitude_cv_metrics( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude spread. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. average_num_spikes_per_bin : int, default: 50 The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is @@ -708,8 +718,6 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -786,16 +794,17 @@ class AmplitudeCV(BaseMetric): "amplitude_cv_median": "Median of the coefficient of variation of spike amplitudes within temporal bins.", "amplitude_cv_range": "Range of the coefficient of variation of spike amplitudes within temporal bins.", } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_amplitude_cutoffs( sorting_analyzer, unit_ids=None, + periods=None, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, - periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -806,6 +815,9 @@ def compute_amplitude_cutoffs( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude cutoffs. If None, all units are used. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. num_histogram_bins : int, default: 100 The number of bins to use to compute the amplitude histogram. histogram_smoothing_value : int, default: 3 @@ -814,9 +826,6 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -881,6 +890,7 @@ class AmplitudeCutoff(BaseMetric): metric_descriptions = { "amplitude_cutoff": "Estimated fraction of missing spikes, based on the amplitude distribution." } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] @@ -929,11 +939,12 @@ class AmplitudeMedian(BaseMetric): metric_descriptions = { "amplitude_median": "Median of the amplitude distributions (in absolute value) for each unit in uV." } + supports_periods = True depend_on = ["spike_amplitudes"] def compute_noise_cutoffs( - sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None + sorting_analyzer, unit_ids=None, periods=None, high_quantile=0.25, low_quantile=0.1, n_bins=100 ): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -952,15 +963,15 @@ def compute_noise_cutoffs( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude cutoffs. If None, all units are used. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. high_quantile : float, default: 0.25 Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. low_quantile : int, default: 0.1 Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -1015,19 +1026,20 @@ class NoiseCutoff(BaseMetric): ), "noise_ratio": "Ratio of counts in the lower-amplitude bins to the count in the highest bin.", } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_drift_metrics( sorting_analyzer, unit_ids=None, + periods=None, interval_s=60, min_spikes_per_interval=100, direction="y", min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, - periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1049,6 +1061,9 @@ def compute_drift_metrics( A SortingAnalyzer object. unit_ids : list or None, default: None List of unit ids to compute the drift metrics. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. interval_s : int, default: 60 Interval length is seconds for computing spike depth. 
min_spikes_per_interval : int, default: 100 @@ -1062,9 +1077,6 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). @@ -1198,16 +1210,17 @@ class Drift(BaseMetric): "drift_std": "Standard deviation of the drift signal in um.", "drift_mad": "Median absolute deviation of the drift signal in um.", } + supports_periods = True depend_on = ["spike_locations"] def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, + periods=None, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, - periods=None, **kwargs, ): """ @@ -1223,6 +1236,9 @@ def compute_sd_ratio( A SortingAnalyzer object. unit_ids : list or None, default: None The list of unit ids to compute this metric. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. censored_period_ms : float, default: 4.0 The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. correct_for_drift : bool, default: True @@ -1230,9 +1246,6 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. @@ -1346,6 +1359,7 @@ class SDRatio(BaseMetric): "sd_ratio": "Ratio between the standard deviation of spike amplitudes and the standard deviation of noise." } needs_recording = True + supports_periods = True depend_on = ["templates", "spike_amplitudes"] From b23c431bc595650b336821811bb87bd3b8026424 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 16:56:16 +0100 Subject: [PATCH 26/40] wip: test metrics with periods --- .../core/analyzer_extension_core.py | 14 +- .../metrics/quality/misc_metrics.py | 1 - .../quality/tests/test_metrics_functions.py | 687 ++++++++++-------- 3 files changed, 397 insertions(+), 305 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc81f063d1..a21404e58f 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -13,6 +13,7 @@ import numpy as np from collections import namedtuple +from .numpyextractors import NumpySorting from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels @@ -1463,6 +1464,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods, ) all_data = all_data[keep_mask] + # since we have the mask already, we can use it directly to avoid double computation + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=True) + sliced_spike_vector = spike_vector[keep_mask] + sorting = NumpySorting( + sliced_spike_vector, 
+ sampling_frequency=self.sorting_analyzer.sampling_frequency, + unit_ids=self.sorting_analyzer.unit_ids, + ) + else: + sorting = self.sorting_analyzer.sorting if outputs == "numpy": if copy: @@ -1474,8 +1485,7 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, if keep_mask is not None: # since we are filtering spikes, we need to recompute the spike indices - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_vector = spike_vector[keep_mask] + spike_vector = sorting.to_spike_vector(concatenated=False) spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) else: # use the cache of indices diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index ec0caf8138..04d451202d 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -200,7 +200,6 @@ class SNR(BaseMetric): metric_params = {"peak_sign": "neg", "peak_mode": "extremum"} metric_columns = {"snr": float} metric_descriptions = {"snr": "Signal to noise ratio for each unit."} - supports_periods = True depend_on = ["noise_levels", "templates"] diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 99e2e5606a..2e31c53135 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from copy import deepcopy import csv @@ -14,13 +13,8 @@ from spikeinterface.metrics.utils import create_ground_truth_pc_distributions -# from spikeinterface.metrics.quality_metric_list import ( -# _misc_metric_name_to_func, -# ) - from spikeinterface.metrics.quality import ( get_quality_metric_list, - get_quality_pca_metric_list, compute_quality_metrics, 
) from spikeinterface.metrics.quality.misc_metrics import ( @@ -28,8 +22,6 @@ compute_amplitude_cutoffs, compute_presence_ratios, compute_isi_violations, - # compute_firing_rates, - # compute_num_spikes, compute_snrs, compute_refrac_period_violations, compute_sliding_rp_violations, @@ -44,7 +36,6 @@ ) from spikeinterface.metrics.quality.pca_metrics import ( - pca_metrics_list, mahalanobis_metrics, d_prime_metric, nearest_neighbors_metrics, @@ -53,258 +44,10 @@ ) -from spikeinterface.core.base import minimum_spike_dtype - - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def test_noise_cutoff(): - """ - Generate two artifical gaussian, one truncated and one not. Check the metrics are higher for the truncated one. - """ - np.random.seed(1) - amps = np.random.normal(0, 1, 1000) - amps_trunc = amps[amps > -1] - - cutoff1, ratio1 = _noise_cutoff(amps=amps) - cutoff2, ratio2 = _noise_cutoff(amps=amps_trunc) - - assert cutoff1 <= cutoff2 - assert ratio1 <= ratio2 - - -def test_compute_new_quality_metrics(small_sorting_analyzer): - """ - Computes quality metrics then computes a subset of quality metrics, and checks - that the old quality metrics are not deleted. 
- """ - - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "firing_range": {"bin_size_s": 1}, - } - - small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - qm_extension = small_sorting_analyzer.get_extension("quality_metrics") - calculated_metrics = list(qm_extension.get_data().keys()) - - assert calculated_metrics == ["snr"] - - small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} - ) - small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - - quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - - # Check old metrics are not deleted and the new one is added to the data and metadata - assert set(list(quality_metric_extension.get_data().keys())) == set( - [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - ) - assert set(list(quality_metric_extension.params.get("metric_names"))) == set( - [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - ) - - # check that, when parameters are changed, the data and metadata are updated - old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) - small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} - ) - new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - new_snr_data = new_quality_metric_extension.get_data()["snr"].values - - assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" - - -def test_metric_names_in_same_order(small_sorting_analyzer): - """ - Computes sepecified quality metrics and checks order is propagated. 
- """ - specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] - small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) - qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() - for i in range(3): - assert specified_metric_names[i] == qm_keys[i] - - -def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): - """ - Computes quality metrics in binary folder format. Then computes subsets of quality - metrics and checks if they are saved correctly. - """ - - # can't use _misc_metric_name_to_func as some functions compute several qms - # e.g. isi_violation and synchrony - quality_metrics = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violations_ratio", - "isi_violations_count", - "rp_contamination", - "rp_violations", - "sliding_rp_violation", - "amplitude_cutoff", - "amplitude_median", - "amplitude_cv_median", - "amplitude_cv_range", - "sync_spike_2", - "sync_spike_4", - "sync_spike_8", - "firing_range", - "drift_ptp", - "drift_std", - "drift_mad", - "sd_ratio", - "isolation_distance", - "l_ratio", - "d_prime", - "silhouette", - "nn_hit_rate", - "nn_miss_rate", - ] - - small_sorting_analyzer.compute("quality_metrics") - - cache_folder = create_cache_folder - output_folder = cache_folder / "sorting_analyzer" - - folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) - quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - assert metric_name in metric_names - - folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for 
metric_name in quality_metrics: - assert metric_name in metric_names - - folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - if metric_name == "snr": - assert metric_name in metric_names - else: - assert metric_name not in metric_names - - -def test_unit_structure_in_output(small_sorting_analyzer): - - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, - "firing_range": {"bin_size_s": 1}, - "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, - "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, - "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, - } - - for metric in misc_metrics_list: - metric_name = metric.metric_name - metric_fun = metric.metric_function - try: - qm_param = qm_params[metric_name] - except: - qm_param = {} - - result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) - result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) - - error = "Problem with metric: " + metric_name - - if isinstance(result_all, dict): - assert list(result_all.keys()) == ["#3", "#9", "#4"], error - assert list(result_sub.keys()) == ["#4", "#9"], error - assert result_sub["#9"] == result_all["#9"], error - assert result_sub["#4"] == result_all["#4"], error - - else: - for result_ind, result in enumerate(result_sub): - - assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error - assert result_sub[result_ind].keys() == set(["#4", "#9"]), error - - assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error - 
assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error - - -def test_unit_id_order_independence(small_sorting_analyzer): - """ - Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels, - and checks that their calculated quality metrics are independent of the ordering and labelling. - """ - - recording = small_sorting_analyzer.recording - sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) - - small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - small_sorting_analyzer_2.compute(extensions_to_compute) - - # need special params to get non-nan results on a short recording - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, - "firing_range": {"bin_size_s": 1}, - "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, - "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, - } - - quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True - ) - quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True - ) - - for metric, metric_2_data in quality_metrics_2.items(): - error = "Problem with the metric " + metric - assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error - assert quality_metrics_1[metric]["#9"] == metric_2_data[7], error - assert quality_metrics_1[metric]["#4"] == metric_2_data[1], 
error +from spikeinterface.core.base import minimum_spike_dtype, unit_period_dtype +### HELPER FUNCTIONS AND FIXTURES ### def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -335,7 +78,6 @@ def _sorting_violation(): def _sorting_analyzer_violations(): - sorting = _sorting_violation() duration = (sorting.to_spike_vector()["sample_index"][-1] + 1) / sorting.sampling_frequency @@ -352,9 +94,87 @@ def _sorting_analyzer_violations(): return sorting_analyzer -@pytest.fixture(scope="module") -def sorting_analyzer_violations(): - return _sorting_analyzer_violations() +@pytest.fixture(scope="module") +def sorting_analyzer_violations(): + return _sorting_analyzer_violations() + + +def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): + """ + Computes and sets periods for each unit in the sorting analyzer. + The periods span the total duration of the recording, but divide it into + smaller periods either by specifying the number of periods or the size of each bin. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer containing the units and recording information. + num_periods : int + The number of periods to divide the total duration into (used if bin_size_s is None). + bin_size_s : float, defaut: None + If given, periods will be multiple of this size in seconds. + + Returns + ------- + periods + np.ndarray of dtype unit_period_dtype containing the segment, start, end samples and unit index. 
+ """ + all_periods = [] + for segment_index in range(sorting_analyzer.recording.get_num_segments()): + samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods + if bin_size_s is not None: + print(f"Original samples_per_period: {samples_per_period} - num_periods: {num_periods}") + bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) + print(samples_per_period / bin_size_samples) + samples_per_period = samples_per_period // bin_size_samples * bin_size_samples + num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) + print(f"Adjusted samples_per_period: {samples_per_period} - num_periods: {num_periods}") + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) + for i, period_start in enumerate(period_starts): + period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) + periods_per_unit[i]["segment_index"] = segment_index + periods_per_unit[i]["start_sample_index"] = period_start + periods_per_unit[i]["end_sample_index"] = period_end + periods_per_unit[i]["unit_index"] = unit_index + print(periods_per_unit, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + all_periods.append(periods_per_unit) + return np.concatenate(all_periods) + + +@pytest.fixture +def periods_simple(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + periods = compute_periods(sorting_analyzer, num_periods=5) + return periods + + +@pytest.fixture +def periods_violations(sorting_analyzer_violations): + sorting_analyzer = sorting_analyzer_violations + periods = compute_periods(sorting_analyzer, num_periods=5) + return periods + + +# Common job kwargs +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +### LOW-LEVEL TESTS 
### +def test_noise_cutoff(): + """ + Generate two artifical gaussian, one truncated and one not. Check the metrics are higher for the truncated one. + """ + np.random.seed(1) + amps = np.random.normal(0, 1, 1000) + amps_trunc = amps[amps > -1] + + cutoff1, ratio1 = _noise_cutoff(amps=amps) + cutoff2, ratio2 = _noise_cutoff(amps=amps_trunc) + + assert cutoff1 <= cutoff2 + assert ratio1 <= ratio2 def test_synchrony_counts_no_sync(): @@ -489,22 +309,13 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -# def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): -# sorting_analyzer = sorting_analyzer_simple -# firing_rates = compute_firing_rates(sorting_analyzer) -# num_spikes = compute_num_spikes(sorting_analyzer) - -# testing method accuracy with magic number is not a good pratcice, I remove this. -# firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} -# num_spikes_gt = {0: 1001, 1: 503, 2: 509} -# assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05) -# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) - - +### TEST METRICS FUNCTIONS ### def test_calculate_firing_range(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - firing_ranges = compute_firing_ranges(sorting_analyzer) - print(firing_ranges) + firing_ranges = compute_firing_ranges(sorting_analyzer, bin_size_s=1) + periods = compute_periods(sorting_analyzer, num_periods=5, bin_size_s=1) + firing_ranges_periods = compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1) + assert firing_ranges == firing_ranges_periods with pytest.warns(UserWarning) as w: firing_ranges_nan = compute_firing_ranges( @@ -517,6 +328,9 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_cuts = compute_amplitude_cutoffs(sorting_analyzer, 
num_histogram_bins=10) + periods = compute_periods(sorting_analyzer, num_periods=5) + amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10) + assert amp_cuts == amp_cuts_periods # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -528,18 +342,26 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_medians = compute_amplitude_medians(sorting_analyzer) - # print(amp_medians) + periods = compute_periods(sorting_analyzer, num_periods=5) + amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods) + assert amp_medians == amp_medians_periods # testing method accuracy with magic number is not a good pratcice, I remove this. # amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): +def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_analyzer, average_num_spikes_per_bin=20) - print(amp_cv_median) - print(amp_cv_range) + periods = periods_simple + amp_cv_median_periods, amp_cv_range_periods = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=periods, + average_num_spikes_per_bin=20, + ) + assert amp_cv_median == amp_cv_median_periods + assert amp_cv_range == amp_cv_range_periods # amps_scalings = compute_amplitude_scalings(sorting_analyzer) sorting_analyzer.compute("amplitude_scalings", **job_kwargs) @@ -549,34 +371,46 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): amplitude_extension="amplitude_scalings", min_num_bins=5, ) - print(amp_cv_median_scalings) - 
print(amp_cv_range_scalings) + amp_cv_median_scalings_periods, amp_cv_range_scalings_periods = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=periods, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + assert amp_cv_median_scalings == amp_cv_median_scalings_periods + assert amp_cv_range_scalings == amp_cv_range_scalings_periods -def test_calculate_snrs(sorting_analyzer_simple): +def test_calculate_snrs(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple snrs = compute_snrs(sorting_analyzer) - print(snrs) + # SNR doesn't support periods # testing method accuracy with magic number is not a good pratcice, I remove this. # snrs_gt = {0: 12.92, 1: 12.99, 2: 12.99} # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(sorting_analyzer_simple): +def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple ratios = compute_presence_ratios(sorting_analyzer, bin_duration_s=10) - print(ratios) - + periods = periods_simple + ratios_periods = compute_presence_ratios(sorting_analyzer, periods=periods, bin_duration_s=10) + assert ratios == ratios_periods # testing method accuracy with magic number is not a good pratcice, I remove this. 
# ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0} # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(sorting_analyzer_violations): +def test_calculate_isi_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations isi_viol, counts = compute_isi_violations(sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0) - print(isi_viol) + periods = periods_violations + isi_viol_periods, counts_periods = compute_isi_violations( + sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0, periods=periods + ) + assert isi_viol == isi_viol_periods # testing method accuracy with magic number is not a good pratcice, I remove this. # isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754} @@ -585,23 +419,30 @@ def test_calculate_isi_violations(sorting_analyzer_violations): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(sorting_analyzer_violations): +def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations contaminations = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) - print(contaminations) + periods = periods_violations + contaminations_periods = compute_sliding_rp_violations( + sorting_analyzer, periods=periods, bin_size_ms=0.25, window_size_s=1 + ) + assert contaminations == contaminations_periods # testing method accuracy with magic number is not a good pratcice, I remove this. 
# contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325} # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def test_calculate_rp_violations(sorting_analyzer_violations): +def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations rp_contamination, counts = compute_refrac_period_violations( sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0 ) - print(rp_contamination, counts) - + periods = periods_violations + rp_contamination_periods, counts_periods = compute_refrac_period_violations( + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=periods + ) + assert rp_contamination == rp_contamination_periods # testing method accuracy with magic number is not a good pratcice, I remove this. # counts_gt = {0: 2, 1: 4, 2: 10} # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0} @@ -620,10 +461,13 @@ def test_calculate_rp_violations(sorting_analyzer_violations): assert np.isnan(rp_contamination[1]) -def test_synchrony_metrics(sorting_analyzer_simple): +def test_synchrony_metrics(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple sorting = sorting_analyzer.sorting synchrony_metrics = compute_synchrony_metrics(sorting_analyzer) + periods = periods_simple + synchrony_metrics_periods = compute_synchrony_metrics(sorting_analyzer, periods=periods) + assert synchrony_metrics == synchrony_metrics_periods synchrony_sizes = np.array([2, 4, 8]) @@ -679,6 +523,13 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics( sorting_analyzer, interval_s=10, min_spikes_per_interval=10 ) + periods = compute_periods(sorting_analyzer, num_periods=5, bin_size_s=10) + drifts_ptps_periods, drifts_stds_periods, drift_mads_periods = compute_drift_metrics( + sorting_analyzer, periods=periods, min_spikes_per_interval=10, interval_s=10 + ) + 
assert drifts_ptps == drifts_ptps_periods + assert drifts_stds == drifts_stds_periods + assert drift_mads == drift_mads_periods # print(drifts_ptps, drifts_stds, drift_mads) @@ -691,25 +542,257 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05) -def test_calculate_sd_ratio(sorting_analyzer_simple): +def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple): sd_ratio = compute_sd_ratio( sorting_analyzer_simple, ) + periods = periods_simple + sd_ratio_periods = compute_sd_ratio(sorting_analyzer_simple, periods=periods) + assert sd_ratio == sd_ratio_periods assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) -if __name__ == "__main__": +### MACHINERY TESTS ### +def test_compute_new_quality_metrics(small_sorting_analyzer): + """ + Computes quality metrics then computes a subset of quality metrics, and checks + that the old quality metrics are not deleted. 
+ """ + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "firing_range": {"bin_size_s": 1}, + } + + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + qm_extension = small_sorting_analyzer.get_extension("quality_metrics") + calculated_metrics = list(qm_extension.get_data().keys()) - sorting_analyzer = _sorting_analyzer_simple() - print(sorting_analyzer) + assert calculated_metrics == ["snr"] - test_unit_structure_in_output(_small_sorting_analyzer()) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} + ) + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - # test_calculate_firing_rate_num_spikes(sorting_analyzer) + quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert set(list(quality_metric_extension.get_data().keys())) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + assert set(list(quality_metric_extension.params.get("metric_names"))) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + + # check that, when parameters are changed, the data and metadata are updated + old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + ) + new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + new_snr_data = new_quality_metric_extension.get_data()["snr"].values + + assert np.all(old_snr_data != new_snr_data) + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" + + +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes 
specified quality metrics and checks order is propagated. + """ + specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] + small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) + qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == qm_keys[i] + + +def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes quality metrics in binary folder format. Then computes subsets of quality + metrics and checks if they are saved correctly. + """ + + # can't use _misc_metric_name_to_func as some functions compute several qms + # e.g. isi_violation and synchrony + quality_metrics = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violations_ratio", + "isi_violations_count", + "rp_contamination", + "rp_violations", + "sliding_rp_violation", + "amplitude_cutoff", + "amplitude_median", + "amplitude_cv_median", + "amplitude_cv_range", + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + "firing_range", + "drift_ptp", + "drift_std", + "drift_mad", + "sd_ratio", + "isolation_distance", + "l_ratio", + "d_prime", + "silhouette", + "nn_hit_rate", + "nn_miss_rate", + ] + + small_sorting_analyzer.compute("quality_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = 
csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + if metric_name == "snr": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + +def test_unit_structure_in_output(small_sorting_analyzer): + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, + } + + for metric in misc_metrics_list: + metric_name = metric.metric_name + metric_fun = metric.metric_function + try: + qm_param = qm_params[metric_name] + except: + qm_param = {} + + result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) + result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) + + error = "Problem with metric: " + metric_name + + if isinstance(result_all, dict): + assert list(result_all.keys()) == ["#3", "#9", "#4"], error + assert list(result_sub.keys()) == ["#4", "#9"], error + assert result_sub["#9"] == result_all["#9"], error + assert result_sub["#4"] == result_all["#4"], error + + else: + for result_ind, result in enumerate(result_sub): + + assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error + assert result_sub[result_ind].keys() == set(["#4", "#9"]), error + + assert 
result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error + assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error + + +def test_unit_id_order_independence(small_sorting_analyzer): + """ + Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels, + and checks that their calculated quality metrics are independent of the ordering and labelling. + """ + + recording = small_sorting_analyzer.recording + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) + + small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer_2.compute(extensions_to_compute) + + # need special params to get non-nan results on a short recording + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + } + quality_metrics_1 = compute_quality_metrics( + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True + ) + quality_metrics_2 = compute_quality_metrics( + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True + ) + + for metric, metric_2_data in quality_metrics_2.items(): + error = "Problem with the metric " + metric + assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error + assert quality_metrics_1[metric]["#9"] == metric_2_data[7], 
error + assert quality_metrics_1[metric]["#4"] == metric_2_data[1], error + + +if __name__ == "__main__": + pass + # sorting_analyzer = _sorting_analyzer_simple() + # print(sorting_analyzer) + # test_unit_structure_in_output(_small_sorting_analyzer()) + # test_calculate_firing_rate_num_spikes(sorting_analyzer) # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) From 0fe7f3e7778a826562dd2d3a9667684a49861cb1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 17:38:18 +0100 Subject: [PATCH 27/40] Fix periods arg in MetricExtensions --- src/spikeinterface/metrics/quality/quality_metrics.py | 2 ++ src/spikeinterface/metrics/template/template_metrics.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 8e96f4dcaf..e5cc2aa323 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -70,6 +70,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods=None, # common extension kwargs peak_sign=None, seed=None, @@ -90,6 +91,7 @@ def _set_params( metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, + periods=periods, peak_sign=peak_sign, seed=seed, skip_pc_metrics=skip_pc_metrics, diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index e27f16963d..85ef9e22cb 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -131,6 +131,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + 
periods=None, # common extension kwargs peak_sign="neg", upsampling_factor=10, @@ -160,6 +161,7 @@ def _set_params( metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, + periods=periods, # template metrics do not use periods peak_sign=peak_sign, upsampling_factor=upsampling_factor, include_multi_channel_metrics=include_multi_channel_metrics, From f087e08ec7b5df7b7fdb122d78e6a4b90b2d116c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 17:57:57 +0100 Subject: [PATCH 28/40] Make bin edges unique --- .../metrics/quality/tests/test_metrics_functions.py | 3 --- src/spikeinterface/metrics/utils.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 2e31c53135..f29e72d153 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -123,12 +123,10 @@ def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): for segment_index in range(sorting_analyzer.recording.get_num_segments()): samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods if bin_size_s is not None: - print(f"Original samples_per_period: {samples_per_period} - num_periods: {num_periods}") bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) print(samples_per_period / bin_size_samples) samples_per_period = samples_per_period // bin_size_samples * bin_size_samples num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) - print(f"Adjusted samples_per_period: {samples_per_period} - num_periods: {num_periods}") for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) 
periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) @@ -138,7 +136,6 @@ def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): periods_per_unit[i]["start_sample_index"] = period_start periods_per_unit[i]["end_sample_index"] = period_end periods_per_unit[i]["unit_index"] = unit_index - print(periods_per_unit, sorting_analyzer.get_num_samples(segment_index), samples_per_period) all_periods.append(periods_per_unit) return np.concatenate(all_periods) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 83ddfcf90b..00db100c1f 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -37,7 +37,7 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per end_sample = seg_start + period["end_sample_index"] end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1 # align to bin bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) - bin_edges_for_units[unit_id] = np.array(bin_edges) + bin_edges_for_units[unit_id] = np.unique(np.array(bin_edges)) else: for unit_id in sorting.unit_ids: bin_edges = [] From 173e7473034089ed27b080958511d939415b5533 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 10:05:28 +0100 Subject: [PATCH 29/40] Add support_periods to spike train metrics and tests --- src/spikeinterface/metrics/conftest.py | 81 +++++++++++++++++- .../metrics/quality/tests/conftest.py | 85 ------------------- .../quality/tests/test_metrics_functions.py | 43 +--------- .../metrics/spiketrain/metrics.py | 4 +- .../spiketrain/tests/test_metric_functions.py | 33 +++++++ src/spikeinterface/metrics/utils.py | 42 +++++++++ 6 files changed, 158 insertions(+), 130 deletions(-) delete mode 100644 src/spikeinterface/metrics/quality/tests/conftest.py create mode 100644 src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py diff --git a/src/spikeinterface/metrics/conftest.py 
b/src/spikeinterface/metrics/conftest.py index 8d32c103fa..5313e763c1 100644 --- a/src/spikeinterface/metrics/conftest.py +++ b/src/spikeinterface/metrics/conftest.py @@ -1,8 +1,85 @@ import pytest -from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +def make_small_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[10.0], + num_units=10, + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer @pytest.fixture(scope="module") def small_sorting_analyzer(): - return _small_sorting_analyzer() + return make_small_analyzer() + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( 
+ unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) + + return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py deleted file mode 100644 index 5313e763c1..0000000000 --- a/src/spikeinterface/metrics/quality/tests/conftest.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest - -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, -) - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def make_small_analyzer(): - recording, sorting = generate_ground_truth_recording( - durations=[10.0], - num_units=10, - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) - - sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - 
"noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {"operators": ["average", "median"]}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - sorting_analyzer.compute(extensions_to_compute) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def small_sorting_analyzer(): - return make_small_analyzer() - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) - - return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index f29e72d153..0356e24ed0 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ 
b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -11,7 +11,7 @@ synthesize_random_firings, ) -from spikeinterface.metrics.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, compute_periods from spikeinterface.metrics.quality import ( get_quality_metric_list, @@ -99,47 +99,6 @@ def sorting_analyzer_violations(): return _sorting_analyzer_violations() -def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): - """ - Computes and sets periods for each unit in the sorting analyzer. - The periods span the total duration of the recording, but divide it into - smaller periods either by specifying the number of periods or the size of each bin. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The sorting analyzer containing the units and recording information. - num_periods : int - The number of periods to divide the total duration into (used if bin_size_s is None). - bin_size_s : float, defaut: None - If given, periods will be multiple of this size in seconds. - - Returns - ------- - periods - np.ndarray of dtype unit_period_dtype containing the segment, start, end samples and unit index. 
- """ - all_periods = [] - for segment_index in range(sorting_analyzer.recording.get_num_segments()): - samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods - if bin_size_s is not None: - bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) - print(samples_per_period / bin_size_samples) - samples_per_period = samples_per_period // bin_size_samples * bin_size_samples - num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) - for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): - period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) - periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) - for i, period_start in enumerate(period_starts): - period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) - periods_per_unit[i]["segment_index"] = segment_index - periods_per_unit[i]["start_sample_index"] = period_start - periods_per_unit[i]["end_sample_index"] = period_end - periods_per_unit[i]["unit_index"] = unit_index - all_periods.append(periods_per_unit) - return np.concatenate(all_periods) - - @pytest.fixture def periods_simple(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index 669733f47a..652be32955 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -40,6 +40,7 @@ class NumSpikes(BaseMetric): metric_params = {} metric_descriptions = {"num_spikes": "Total number of spikes for each unit across all segments."} metric_columns = {"num_spikes": int} + supports_periods = True def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): @@ -68,7 +69,7 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): total_durations = 
compute_total_durations_per_unit(sorting_analyzer, periods=periods) firing_rates = {} - num_spikes = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) for unit_id in unit_ids: if num_spikes[unit_id] == 0: firing_rates[unit_id] = np.nan @@ -83,6 +84,7 @@ class FiringRate(BaseMetric): metric_params = {} metric_descriptions = {"firing_rate": "Firing rate (spikes per second) for each unit across all segments."} metric_columns = {"firing_rate": float} + supports_periods = True spiketrain_metrics = [NumSpikes, FiringRate] diff --git a/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py new file mode 100644 index 0000000000..86a5e9db2d --- /dev/null +++ b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py @@ -0,0 +1,33 @@ +import numpy as np + +from spikeinterface.core.base import unit_period_dtype +from spikeinterface.metrics.utils import compute_periods +from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates + + +def test_calculate_num_spikes(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + num_spikes = compute_num_spikes(sorting_analyzer) + periods = compute_periods(sorting_analyzer, num_periods=5) + num_spikes_periods = compute_num_spikes(sorting_analyzer, periods=periods) + assert num_spikes == num_spikes_periods + + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + num_spikes_empty_periods = compute_num_spikes(sorting_analyzer, periods=empty_periods) + assert num_spikes_empty_periods == {unit_id: 0 for unit_id in sorting_analyzer.sorting.unit_ids} + + +def test_calculate_firing_rates(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = 
sorting_analyzer.get_extension("spike_amplitudes").get_data() + firing_rates = compute_firing_rates(sorting_analyzer) + periods = compute_periods(sorting_analyzer, num_periods=5) + firing_rates_periods = compute_firing_rates(sorting_analyzer, periods=periods) + assert firing_rates == firing_rates_periods + + # calculate firing rates with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + firing_rates_empty_periods = compute_firing_rates(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(firing_rates_empty_periods.values())))) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 00db100c1f..e007b19c05 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +from spikeinterface.core.base import unit_period_dtype def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): @@ -108,6 +109,47 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): return total_durations +def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): + """ + Computes and sets periods for each unit in the sorting analyzer. + The periods span the total duration of the recording, but divide it into + smaller periods either by specifying the number of periods or the size of each bin. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer containing the units and recording information. + num_periods : int + The number of periods to divide the total duration into (used if bin_size_s is None). + bin_size_s : float, default: None + If given, periods will be multiple of this size in seconds. + + Returns + ------- + periods + np.ndarray of dtype unit_period_dtype containing the segment, start, end samples and unit index.
+ """ + all_periods = [] + for segment_index in range(sorting_analyzer.recording.get_num_segments()): + samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods + if bin_size_s is not None: + bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) + print(samples_per_period / bin_size_samples) + samples_per_period = samples_per_period // bin_size_samples * bin_size_samples + num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) + for i, period_start in enumerate(period_starts): + period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) + periods_per_unit[i]["segment_index"] = segment_index + periods_per_unit[i]["start_sample_index"] = period_start + periods_per_unit[i]["end_sample_index"] = period_end + periods_per_unit[i]["unit_index"] = unit_index + all_periods.append(periods_per_unit) + return np.concatenate(all_periods) + + def create_ground_truth_pc_distributions(center_locations, total_points): """ Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics From 066c3787171c46ef2bd39727bcbede3e015ec86d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:11:30 +0100 Subject: [PATCH 30/40] Force NaN/-1 values for float/int metrics if num_spikes is 0 --- .../metrics/quality/misc_metrics.py | 44 +++++++++-- .../quality/tests/test_metrics_functions.py | 73 ++++++++++++++++++- 2 files changed, 108 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 04d451202d..7bf7ff0f86 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ 
b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -75,6 +75,7 @@ def compute_presence_ratios( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) @@ -104,6 +105,9 @@ def compute_presence_ratios( else: for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + presence_ratios[unit_id] = np.nan + continue spike_train = [] bin_edges = bin_edges_per_unit[unit_id] if len(bin_edges) < 2: @@ -264,6 +268,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_th unit_ids = sorting_analyzer.unit_ids total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -273,15 +278,17 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_th isi_violations_ratio = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + isi_violations_ratio[unit_id] = np.nan + isi_violations_count[unit_id] = -1 + continue + spike_train_list = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) if len(spike_train) > 0: spike_train_list.append(spike_train / fs) - if not any([len(train) > 0 for train in spike_train_list]): - continue - total_duration = total_durations[unit_id] ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) @@ -359,7 +366,7 @@ def compute_refrac_period_violations( if not HAVE_NUMBA: warnings.warn("Error: numba is not installed.") warnings.warn("compute_refrac_period_violations cannot run without numba.") - return 
{unit_id: np.nan for unit_id in unit_ids} + return res({unit_id: np.nan for unit_id in unit_ids}, {unit_id: 0 for unit_id in unit_ids}) num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) @@ -372,6 +379,11 @@ def compute_refrac_period_violations( nb_violations = {} rp_contamination = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + rp_contamination[unit_id] = np.nan + nb_violations[unit_id] = -1 + continue + nb_violations[unit_id] = 0 total_samples_unit = total_samples[unit_id] @@ -556,7 +568,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, syn if unit_ids is None: unit_ids = sorting.unit_ids - spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids @@ -569,10 +581,10 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, syn for i, unit_id in enumerate(all_unit_ids): if unit_id not in unit_ids: continue - if spike_counts[unit_id] != 0: - sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id] + if num_spikes[unit_id] != 0: + sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / num_spikes[unit_id] else: - sync_id_metrics_dict[unit_id] = 0 + sync_id_metrics_dict[unit_id] = -1 synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict return res(**synchrony_metrics_dict) @@ -629,6 +641,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz if unit_ids is None: unit_ids = sorting.unit_ids + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + if all( [ sorting_analyzer.get_num_samples(segment_index) < bin_size_samples @@ -648,6 +662,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz ) cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: + if 
num_spikes[unit_id] == 0: + continue bin_edges = bin_edges_per_unit[unit_id] # we can concatenate spike trains across segments adding the cumulative number of samples @@ -665,6 +681,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz # finally we compute the percentiles firing_ranges = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + firing_ranges[unit_id] = np.nan + continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( firing_rate_histograms[unit_id], percentiles[0] ) @@ -748,6 +767,10 @@ def compute_amplitude_cv_metrics( amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + continue total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( @@ -1267,6 +1290,8 @@ def compute_sd_ratio( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + if not sorting_analyzer.has_recording(): warnings.warn( "The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object" @@ -1297,6 +1322,9 @@ def compute_sd_ratio( sd_ratio = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + sd_ratio[unit_id] = np.nan + continue spk_amp = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id, segment_index) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 0356e24ed0..c13f1ffbaa 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -273,6 +273,10 @@ def test_calculate_firing_range(sorting_analyzer_simple): firing_ranges_periods = 
compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1) assert firing_ranges == firing_ranges_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + firing_ranges_empty = compute_firing_ranges(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(firing_ranges_empty.values())))) + with pytest.warns(UserWarning) as w: firing_ranges_nan = compute_firing_ranges( sorting_analyzer, bin_size_s=sorting_analyzer.get_total_duration() + 1 @@ -287,6 +291,10 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple): periods = compute_periods(sorting_analyzer, num_periods=5) amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10) assert amp_cuts == amp_cuts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_cuts_empty = compute_amplitude_cutoffs(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(amp_cuts_empty.values())))) # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -302,6 +310,10 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods) assert amp_medians == amp_medians_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_medians_empty = compute_amplitude_medians(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(amp_medians_empty.values())))) + # testing method accuracy with magic number is not a good pratcice, I remove this. 
# amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) @@ -319,6 +331,15 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple) assert amp_cv_median == amp_cv_median_periods assert amp_cv_range == amp_cv_range_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_cv_median_empty, amp_cv_range_empty = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=empty_periods, + average_num_spikes_per_bin=20, + ) + assert np.all(np.isnan(np.array(list(amp_cv_median_empty.values())))) + assert np.all(np.isnan(np.array(list(amp_cv_range_empty.values())))) + # amps_scalings = compute_amplitude_scalings(sorting_analyzer) sorting_analyzer.compute("amplitude_scalings", **job_kwargs) amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( @@ -354,6 +375,10 @@ def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple): periods = periods_simple ratios_periods = compute_presence_ratios(sorting_analyzer, periods=periods, bin_duration_s=10) assert ratios == ratios_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + ratios_periods_empty = compute_presence_ratios(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(ratios_periods_empty.values())))) # testing method accuracy with magic number is not a good pratcice, I remove this. 
# ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0} # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) @@ -367,6 +392,12 @@ def test_calculate_isi_violations(sorting_analyzer_violations, periods_violation sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0, periods=periods ) assert isi_viol == isi_viol_periods + assert counts == counts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + isi_viol_empty, isi_counts_empty = compute_isi_violations(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(isi_viol_empty.values())))) + assert np.array_equal(np.array(list(isi_counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))) # testing method accuracy with magic number is not a good pratcice, I remove this. # isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754} @@ -384,6 +415,12 @@ def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_vi ) assert contaminations == contaminations_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + contaminations_periods_empty = compute_sliding_rp_violations( + sorting_analyzer, periods=empty_periods, bin_size_ms=0.25, window_size_s=1 + ) + assert np.all(np.isnan(np.array(list(contaminations_periods_empty.values())))) + # testing method accuracy with magic number is not a good pratcice, I remove this. 
# contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325} # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) @@ -399,6 +436,15 @@ def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=periods ) assert rp_contamination == rp_contamination_periods + assert counts == counts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + rp_contamination_empty, counts_empty = compute_refrac_period_violations( + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=empty_periods + ) + assert np.all(np.isnan(np.array(list(rp_contamination_empty.values())))) + assert np.array_equal(np.array(list(counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))) + # testing method accuracy with magic number is not a good pratcice, I remove this. # counts_gt = {0: 2, 1: 4, 2: 10} # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0} @@ -425,8 +471,19 @@ def test_synchrony_metrics(sorting_analyzer_simple, periods_simple): synchrony_metrics_periods = compute_synchrony_metrics(sorting_analyzer, periods=periods) assert synchrony_metrics == synchrony_metrics_periods - synchrony_sizes = np.array([2, 4, 8]) + empty_periods = np.empty(0, dtype=unit_period_dtype) + synchrony_metrics_empty = compute_synchrony_metrics(sorting_analyzer, periods=empty_periods) + assert np.array_equal( + np.array(list(synchrony_metrics_empty.sync_spike_2.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)) + ) + assert np.array_equal( + np.array(list(synchrony_metrics_empty.sync_spike_4.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)) + ) + assert np.array_equal( + np.array(list(synchrony_metrics_empty.sync_spike_8.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)) + ) + synchrony_sizes = np.array([2, 4, 8]) # check returns for size in synchrony_sizes: assert f"sync_spike_{size}" in synchrony_metrics._fields @@ 
-487,6 +544,15 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): assert drifts_stds == drifts_stds_periods assert drift_mads == drift_mads_periods + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + drifts_ptps_empty, drifts_stds_empty, drift_mads_empty = compute_drift_metrics( + sorting_analyzer_simple, periods=empty_periods + ) + assert np.all(np.isnan(np.array(list(drifts_ptps_empty.values())))) + assert np.all(np.isnan(np.array(list(drifts_stds_empty.values())))) + assert np.all(np.isnan(np.array(list(drift_mads_empty.values())))) + # print(drifts_ptps, drifts_stds, drift_mads) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -507,6 +573,11 @@ def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple): assert sd_ratio == sd_ratio_periods assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) + + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + sd_ratios_empty_periods = compute_sd_ratio(sorting_analyzer_simple, periods=empty_periods) + assert np.all(np.isnan(np.array(list(sd_ratios_empty_periods.values())))) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) From 65e18488860ce63d8766834bec4efc09602f6df7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:27:52 +0100 Subject: [PATCH 31/40] Fix test_empty_units: -1 is a valid value for ints --- .../metrics/quality/tests/test_quality_metric_calculator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index ec72fdc178..2e87002018 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ 
b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -168,7 +168,8 @@ def test_empty_units(sorting_analyzer_simple): for col in metrics_empty.columns: all_nans = np.all(isnull(metrics_empty.loc[empty_unit_ids, col].values)) all_zeros = np.all(metrics_empty.loc[empty_unit_ids, col].values == 0) - assert all_nans or all_zeros + all_neg_ones = np.all(metrics_empty.loc[empty_unit_ids, col].values == -1) + assert all_nans or all_zeros or all_neg_ones, f"Column {col} failed the empty unit test" if __name__ == "__main__": From f1c46828a5cf40944c2d0f8800bba6fdee484197 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:39:45 +0100 Subject: [PATCH 32/40] Fix firing range if unit samples < bin samples --- src/spikeinterface/metrics/quality/misc_metrics.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 7bf7ff0f86..4556fcf09a 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -642,15 +642,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz unit_ids = sorting.unit_ids num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) - - if all( - [ - sorting_analyzer.get_num_samples(segment_index) < bin_size_samples - for segment_index in range(sorting_analyzer.get_num_segments()) - ] - ): - warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. 
Firing ranges are set to NaN.") - return {unit_id: np.nan for unit_id in unit_ids} + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in unit_ids} @@ -662,7 +654,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz ) cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: - if num_spikes[unit_id] == 0: + if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: continue bin_edges = bin_edges_per_unit[unit_id] @@ -681,7 +673,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz # finally we compute the percentiles firing_ranges = {} for unit_id in unit_ids: - if num_spikes[unit_id] == 0: + if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: firing_ranges[unit_id] = np.nan continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( From 32916382980290b6b75a99c0e9c0f524486af92b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:51:19 +0100 Subject: [PATCH 33/40] fix noise_cutoff if empty units --- src/spikeinterface/metrics/quality/misc_metrics.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4556fcf09a..83ac82bd73 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -1018,6 +1018,10 @@ def compute_noise_cutoffs( for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] + if len(amplitudes) == 0: + cutoff, ratio = np.nan, np.nan + continue + if invert_amplitudes: amplitudes = -amplitudes From b5bf3c3f03fde75a007133ff356dd6a278a6f34c Mon Sep 17 00:00:00 2001 From: Alessio Buccino 
Date: Wed, 21 Jan 2026 13:10:56 +0100 Subject: [PATCH 34/40] Move warnings at the end of the loop for firing range and drift --- .../metrics/quality/misc_metrics.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 83ac82bd73..ab8dc670d2 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -672,13 +672,20 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz # finally we compute the percentiles firing_ranges = {} + failed_units = [] for unit_id in unit_ids: if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: + failed_units.append(unit_id) firing_ranges[unit_id] = np.nan continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( firing_rate_histograms[unit_id], percentiles[0] ) + if len(failed_units) > 0: + warnings.warn( + f"Firing range could not be computed for units {failed_units} " + f"because they have no spikes or the total duration is less than bin size." + ) return firing_ranges @@ -1156,18 +1163,16 @@ def compute_drift_metrics( bin_duration_s=interval_s, ) + failed_units = [] median_positions_per_unit = {} for unit_id in unit_ids: bins = bin_edges_for_units[unit_id] num_bins = len(bins) - 1 if num_bins < min_num_bins: - warnings.warn( - f"Unit {unit_id} has only {num_bins} bins given the specified 'interval_s' and " - f"'min_num_bins'. 
Drift metrics will be set to NaN" - ) drift_ptps[unit_id] = np.nan drift_stds[unit_id] = np.nan drift_mads[unit_id] = np.nan + failed_units.append(unit_id) continue # bin_edges are global across segments, so we have to use spike_sample_indices, @@ -1191,6 +1196,7 @@ def compute_drift_metrics( if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): + failed_units.append(unit_id) ptp_drift = np.nan std_drift = np.nan mad_drift = np.nan @@ -1206,6 +1212,12 @@ def compute_drift_metrics( drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + if len(failed_units) > 0: + warnings.warn( + f"Drift metrics could not be computed for units {failed_units} because they have less than " + f"{min_num_bins} bins given the specified 'interval_s' and 'min_num_bins' or not enough valid intervals." + ) + if return_positions: outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit else: From 8aeedccf3fe36808f0ec145ca39b62bbe24e7855 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 09:55:49 +0100 Subject: [PATCH 35/40] clean up tests and add get_available_metric_names --- .../core/analyzer_extension_core.py | 11 +++++ .../quality/tests/test_metrics_functions.py | 45 +++---------------- 2 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index a21404e58f..30038bc270 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -901,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension): need_backward_compatibility_on_load = False metric_list: list[BaseMetric] = None # list of BaseMetric + @classmethod + def get_available_metric_names(cls): + """Get the available metric names. 
+ + Returns + ------- + available_metric_names : list[str] + List of available metric names. + """ + return [m.metric_name for m in cls.metric_list] + @classmethod def get_default_metric_params(cls): """Get the default metric parameters. diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c13f1ffbaa..61f014c289 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -13,10 +13,7 @@ from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, compute_periods -from spikeinterface.metrics.quality import ( - get_quality_metric_list, - compute_quality_metrics, -) +from spikeinterface.metrics.quality import get_quality_metric_list, compute_quality_metrics, ComputeQualityMetrics from spikeinterface.metrics.quality.misc_metrics import ( misc_metrics_list, compute_amplitude_cutoffs, @@ -657,37 +654,9 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): # can't use _misc_metric_name_to_func as some functions compute several qms # e.g. 
isi_violation and synchrony - quality_metrics = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violations_ratio", - "isi_violations_count", - "rp_contamination", - "rp_violations", - "sliding_rp_violation", - "amplitude_cutoff", - "amplitude_median", - "amplitude_cv_median", - "amplitude_cv_range", - "sync_spike_2", - "sync_spike_4", - "sync_spike_8", - "firing_range", - "drift_ptp", - "drift_std", - "drift_mad", - "sd_ratio", - "isolation_distance", - "l_ratio", - "d_prime", - "silhouette", - "nn_hit_rate", - "nn_miss_rate", - ] - - small_sorting_analyzer.compute("quality_metrics") + quality_metric_columns = ComputeQualityMetrics.get_metric_columns() + all_metrics = ComputeQualityMetrics.get_available_metric_names() + small_sorting_analyzer.compute("quality_metrics", metric_names=all_metrics) cache_folder = create_cache_folder output_folder = cache_folder / "sorting_analyzer" @@ -699,7 +668,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in quality_metrics: + for metric_name in quality_metric_columns: assert metric_name in metric_names folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) @@ -708,7 +677,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in quality_metrics: + for metric_name in quality_metric_columns: assert metric_name in metric_names folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) @@ -717,7 +686,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in quality_metrics: + for metric_name in quality_metric_columns: if metric_name == "snr": assert metric_name in metric_names 
else: From d4db43cab085ef362c6200f35f10914ff5273019 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 09:57:34 +0100 Subject: [PATCH 36/40] simplify total samples --- src/spikeinterface/metrics/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index e007b19c05..dea652985f 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -82,7 +82,8 @@ def compute_total_samples_per_unit(sorting_analyzer, periods=None): num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] total_samples[unit_id] = num_samples_in_period else: - total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} + total = sorting_analyzer.get_total_samples() + total_samples = {unit_id: total for unit_id in sorting_analyzer.unit_ids} return total_samples From d0a1e66c68127e41875ecb7b9d90d4dabf95be99 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 10:24:01 +0100 Subject: [PATCH 37/40] Go back to Pierre's implementation for drifts --- .../metrics/quality/misc_metrics.py | 78 ++++++++----------- src/spikeinterface/metrics/utils.py | 37 +++++++-- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index ab8dc670d2..198e98037c 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -1129,13 +1129,20 @@ def compute_drift_metrics( unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations_array = spike_locations_ext.get_data(periods=periods) + spike_locations_by_unit_and_segments = spike_locations_ext.get_data( + outputs="by_unit", concatenated=False, periods=periods + ) spike_locations_by_unit = 
spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] - assert direction in spike_locations_array.dtype.names, ( - f"Direction {direction} is invalid. Available directions: " f"{spike_locations_array.dtype.names}" + data = spike_locations_by_unit[unit_ids[0]] + assert direction in data.dtype.names, ( + f"Direction {direction} is invalid. Available directions: " f"{data.dtype.names}" + ) + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, segment_samples=segment_samples, periods=periods, bin_duration_s=interval_s, concatenated=False ) + failed_units = [] # we need drift_ptps = {} @@ -1144,62 +1151,43 @@ def compute_drift_metrics( # reference positions are the medians across segments reference_positions = {} + median_position_segments = {unit_id: np.array([]) for unit_id in unit_ids} + for unit_id in unit_ids: reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) - # now compute median positions and concatenate them over segments - spike_vector = sorting.to_spike_vector() - spike_sample_indices = spike_vector["sample_index"].copy() - # we need to add the cumulative sum of segment samples to have global sample indices - cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for segment_index in range(sorting_analyzer.get_num_segments()): - segment_slice = sorting._get_spike_vector_segment_slices()[segment_index] - spike_sample_indices[segment_slice[0] : segment_slice[1]] += cumulative_segment_samples[segment_index] - - bin_edges_for_units = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=interval_s, - ) - - failed_units = [] - median_positions_per_unit = {} + for unit_id in unit_ids: + bins = bin_edges_for_units[unit_id][segment_index] + num_bin_edges = len(bins) + if (num_bin_edges - 1) < min_num_bins: + 
failed_units.append(unit_id) + continue + median_positions = np.nan * np.zeros((num_bin_edges - 1)) + spikes_in_segment_of_unit = sorting.get_unit_spike_train(unit_id, segment_index) + bounds = np.searchsorted(spikes_in_segment_of_unit, bins, side="left") + for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])): + spike_locations_in_bin = spike_locations_by_unit_and_segments[segment_index][unit_id][i0:i1][direction] + if (i1 - i0) >= min_spikes_per_interval: + median_positions[bin_index] = np.median(spike_locations_in_bin) + median_position_segments[unit_id] = np.concatenate((median_position_segments[unit_id], median_positions)) + + # finally, compute deviations and drifts for unit_id in unit_ids: - bins = bin_edges_for_units[unit_id] - num_bins = len(bins) - 1 - if num_bins < min_num_bins: + # Skip units that already failed because not enough bins in at least one segment + if unit_id in failed_units: drift_ptps[unit_id] = np.nan drift_stds[unit_id] = np.nan drift_mads[unit_id] = np.nan - failed_units.append(unit_id) continue - - # bin_edges are global across segments, so we have to use spike_sample_indices, - # since we offseted them to be global - bin_spike_indices = np.searchsorted(spike_sample_indices, bins) - median_positions = np.nan * np.zeros(num_bins) - for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): - spikes_in_bin = spike_vector[i0:i1] - spike_locations_in_bin = spike_locations_array[i0:i1][direction] - - unit_index = sorting_analyzer.sorting.id_to_index(unit_id) - mask = spikes_in_bin["unit_index"] == unit_index - if np.sum(mask) >= min_spikes_per_interval: - median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) - else: - median_positions[bin_index] = np.nan - median_positions_per_unit[unit_id] = median_positions - - # now compute deviations and drifts for this unit - position_diff = median_positions - reference_positions[unit_id] + position_diff = median_position_segments[unit_id] - 
reference_positions[unit_id]
         if np.any(np.isnan(position_diff)):
             # deal with nans: if more than 50% nans --> set to nan
             if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff):
                 ptp_drift = np.nan
                 std_drift = np.nan
                 mad_drift = np.nan
+                failed_units.append(unit_id)
             else:
                 ptp_drift = np.nanmax(position_diff) - np.nanmin(position_diff)
                 std_drift = np.nanstd(np.abs(position_diff))
@@ -1219,7 +1207,7 @@ def compute_drift_metrics(
         )
 
     if return_positions:
-        outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit
+        outs = res(drift_ptps, drift_stds, drift_mads), median_positions
     else:
         outs = res(drift_ptps, drift_stds, drift_mads)
     return outs
diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py
index dea652985f..235ae5cd16 100644
--- a/src/spikeinterface/metrics/utils.py
+++ b/src/spikeinterface/metrics/utils.py
@@ -4,7 +4,7 @@
 from spikeinterface.core.base import unit_period_dtype
 
 
-def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None):
+def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None, concatenated=True):
     """
     Compute bin edges for units, optionally taking into account periods.
 
@@ -18,6 +18,16 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per
         Duration of each bin in seconds
     periods : array of unit_period_dtype, default: None
         Periods to consider for each unit
+    concatenated : bool, default: True
+        Whether the bins are concatenated across segments or not.
+        If False, the bin edges are computed per segment and the first index of each segment is 0.
+        If True, the bin edges are computed on the concatenated segments, with the correct offsets.
+
+    Returns
+    -------
+    dict
+        Bin edges for each unit. If concatenated is True, the bin edges are a 1D array.
+        If False, the bin edges are a list of arrays, one per segment.
""" bin_edges_for_units = {} num_segments = len(segment_samples) @@ -31,27 +41,38 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per for seg_index in range(num_segments): seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] if len(seg_periods) == 0: + if not concatenated: + bin_edges.append(np.array([])) continue - seg_start = np.sum(segment_samples[:seg_index]) + seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0 + bin_edges_segment = [] for period in seg_periods: start_sample = seg_start + period["start_sample_index"] end_sample = seg_start + period["end_sample_index"] end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1 # align to bin - bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) - bin_edges_for_units[unit_id] = np.unique(np.array(bin_edges)) + bin_edges_segment.extend(np.arange(start_sample, end_sample, bin_duration_samples)) + bin_edges_segment = np.unique(np.array(bin_edges_segment)) + if concatenated: + bin_edges.extend(bin_edges_segment) + else: + bin_edges.append(bin_edges_segment) + bin_edges_for_units[unit_id] = bin_edges else: for unit_id in sorting.unit_ids: bin_edges = [] for seg_index in range(num_segments): - seg_start = np.sum(segment_samples[:seg_index]) + seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0 seg_end = seg_start + segment_samples[seg_index] # for segments which are not the last, we don't need to correct the end # since the first index of the next segment will be the end of the current segment if seg_index == num_segments - 1: seg_end = seg_end // bin_duration_samples * bin_duration_samples + 1 # align to bin - bins = np.arange(seg_start, seg_end, bin_duration_samples) - bin_edges.extend(bins) - bin_edges_for_units[unit_id] = np.array(bin_edges) + bin_edges_segment = np.arange(seg_start, seg_end, bin_duration_samples) + if concatenated: + bin_edges.extend(bin_edges_segment) + else: + 
bin_edges.append(bin_edges_segment) + bin_edges_for_units[unit_id] = bin_edges return bin_edges_for_units From 630c6622b0c77114158dec594368c1990128146b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 11:08:04 +0100 Subject: [PATCH 38/40] rename compute_periods to compute_regular_periods --- .../quality/tests/test_metrics_functions.py | 14 +++++++------- .../spiketrain/tests/test_metric_functions.py | 6 +++--- src/spikeinterface/metrics/utils.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c13f1ffbaa..cccd15f8a5 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -11,7 +11,7 @@ synthesize_random_firings, ) -from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, compute_periods +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, create_regular_periods from spikeinterface.metrics.quality import ( get_quality_metric_list, @@ -102,14 +102,14 @@ def sorting_analyzer_violations(): @pytest.fixture def periods_simple(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - periods = compute_periods(sorting_analyzer, num_periods=5) + periods = create_regular_periods(sorting_analyzer, num_periods=5) return periods @pytest.fixture def periods_violations(sorting_analyzer_violations): sorting_analyzer = sorting_analyzer_violations - periods = compute_periods(sorting_analyzer, num_periods=5) + periods = create_regular_periods(sorting_analyzer, num_periods=5) return periods @@ -269,7 +269,7 @@ def test_simplified_silhouette_score_metrics(): def test_calculate_firing_range(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple firing_ranges = compute_firing_ranges(sorting_analyzer, bin_size_s=1) - periods = 
compute_periods(sorting_analyzer, num_periods=5, bin_size_s=1) + periods = create_regular_periods(sorting_analyzer, num_periods=5, bin_size_s=1) firing_ranges_periods = compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1) assert firing_ranges == firing_ranges_periods @@ -288,7 +288,7 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_cuts = compute_amplitude_cutoffs(sorting_analyzer, num_histogram_bins=10) - periods = compute_periods(sorting_analyzer, num_periods=5) + periods = create_regular_periods(sorting_analyzer, num_periods=5) amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10) assert amp_cuts == amp_cuts_periods @@ -306,7 +306,7 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_medians = compute_amplitude_medians(sorting_analyzer) - periods = compute_periods(sorting_analyzer, num_periods=5) + periods = create_regular_periods(sorting_analyzer, num_periods=5) amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods) assert amp_medians == amp_medians_periods @@ -536,7 +536,7 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics( sorting_analyzer, interval_s=10, min_spikes_per_interval=10 ) - periods = compute_periods(sorting_analyzer, num_periods=5, bin_size_s=10) + periods = create_regular_periods(sorting_analyzer, num_periods=5, bin_size_s=10) drifts_ptps_periods, drifts_stds_periods, drift_mads_periods = compute_drift_metrics( sorting_analyzer, periods=periods, min_spikes_per_interval=10, interval_s=10 ) diff --git a/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py 
b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py index 86a5e9db2d..7577c767d6 100644 --- a/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py +++ b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py @@ -1,7 +1,7 @@ import numpy as np from spikeinterface.core.base import unit_period_dtype -from spikeinterface.metrics.utils import compute_periods +from spikeinterface.metrics.utils import create_regular_periods from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates @@ -9,7 +9,7 @@ def test_calculate_num_spikes(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() num_spikes = compute_num_spikes(sorting_analyzer) - periods = compute_periods(sorting_analyzer, num_periods=5) + periods = create_regular_periods(sorting_analyzer, num_periods=5) num_spikes_periods = compute_num_spikes(sorting_analyzer, periods=periods) assert num_spikes == num_spikes_periods @@ -23,7 +23,7 @@ def test_calculate_firing_rates(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() firing_rates = compute_firing_rates(sorting_analyzer) - periods = compute_periods(sorting_analyzer, num_periods=5) + periods = create_regular_periods(sorting_analyzer, num_periods=5) firing_rates_periods = compute_firing_rates(sorting_analyzer, periods=periods) assert firing_rates == firing_rates_periods diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index e007b19c05..0616c5bb7b 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -109,7 +109,7 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): return total_durations -def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): +def create_regular_periods(sorting_analyzer, 
num_periods, bin_size_s=None): """ Computes and sets periods for each unit in the sorting analyzer. The periods span the total duration of the recording, but divide it into From 1fd1fd4d216172c2eaddf491738af1f6d2d57d9a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 11:14:01 +0100 Subject: [PATCH 39/40] Remove print --- src/spikeinterface/metrics/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 0616c5bb7b..df0ed7d1e9 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -134,7 +134,6 @@ def create_regular_periods(sorting_analyzer, num_periods, bin_size_s=None): samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods if bin_size_s is not None: bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) - print(samples_per_period / bin_size_samples) samples_per_period = samples_per_period // bin_size_samples * bin_size_samples num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): From f0d0ba7b9fe12d5316cd3a337439b9bc97872475 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 11:20:06 +0100 Subject: [PATCH 40/40] Speed up function which was already fast but Sam didn't like it --- src/spikeinterface/metrics/utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 08fb00645f..b7b94294eb 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -93,15 +93,12 @@ def compute_total_samples_per_unit(sorting_analyzer, periods=None): Total number of samples for each unit. 
""" if periods is not None: - total_samples = {} + total_samples_array = np.zeros(len(sorting_analyzer.unit_ids), dtype="int64") sorting = sorting_analyzer.sorting - for unit_id in sorting.unit_ids: - unit_index = sorting.id_to_index(unit_id) - periods_unit = periods[periods["unit_index"] == unit_index] - num_samples_in_period = 0 - for period in periods_unit: - num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] - total_samples[unit_id] = num_samples_in_period + for period in periods: + unit_index = period["unit_index"] + total_samples_array[unit_index] += period["end_sample_index"] - period["start_sample_index"] + total_samples = dict(zip(sorting.unit_ids, total_samples_array)) else: total = sorting_analyzer.get_total_samples() total_samples = {unit_id: total for unit_id in sorting_analyzer.unit_ids}