diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index a59b040b3b..30038bc270 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -13,6 +13,7 @@ import numpy as np from collections import namedtuple +from .numpyextractors import NumpySorting from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels @@ -823,10 +824,9 @@ class BaseMetric: metric_columns = {} # column names and their dtypes of the dataframe metric_descriptions = {} # descriptions of each metric column needs_recording = False # whether the metric needs recording - needs_tmp_data = ( - False # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level - ) - needs_job_kwargs = False + needs_tmp_data = False # whether the metric needs temporary data computed with MetricExtension._prepare_data + needs_job_kwargs = False # whether the metric needs job_kwargs + supports_periods = False # whether the metric function supports periods depend_on = [] # extensions the metric depends on # the metric function must have the signature: @@ -839,7 +839,7 @@ class BaseMetric: metric_function = None # to be defined in subclass @classmethod - def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs): + def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs, periods=None): """Compute the metric. Parameters @@ -854,6 +854,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs Temporary data to pass to the metric function job_kwargs : dict Job keyword arguments to control parallelization + periods : np.ndarray | None + Numpy array of unit periods of unit_period_dtype if supports_periods is True Returns ------- @@ -865,6 +867,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs args += (tmp_data,) if cls.needs_job_kwargs: args += (job_kwargs,) + if cls.supports_periods: + args += (periods,) results = cls.metric_function(*args, **metric_params) @@ -897,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension): need_backward_compatibility_on_load = False metric_list: list[BaseMetric] = None # list of BaseMetric + @classmethod + def get_available_metric_names(cls): + """Get the available metric names. + + Returns + ------- + available_metric_names : list[str] + List of available metric names. + """ + return [m.metric_name for m in cls.metric_list] + @classmethod def get_default_metric_params(cls): """Get the default metric parameters. @@ -988,6 +1003,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods: np.ndarray | None = None, **other_params, ): """ @@ -1004,6 +1020,8 @@ def _set_params( If True, existing metrics in the extension will be deleted before computing new ones. metrics_to_compute : list[str] | None List of metric names to compute. If None, all metrics in `metric_names` are computed. + periods : np.ndarray | None + Numpy array of unit_period_dtype defining periods to compute metrics over. other_params : dict Additional parameters for metric computation. 
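To make the flag-driven dispatch in `BaseMetric.compute` concrete, here is a minimal sketch of a subclass honoring the contract above. It is illustrative only: the class, function, and column names are invented, and it assumes the base positional arguments are `(sorting_analyzer, unit_ids)`, as the metric-function signatures later in this patch suggest.

from spikeinterface.core.analyzer_extension_core import BaseMetric

def _toy_rate(sorting_analyzer, unit_ids, periods=None):
    # Hypothetical metric function: spikes per second for each unit.
    # `periods` arrives because the subclass sets supports_periods = True;
    # it is ignored here only to keep the sketch short.
    num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
    total_duration = sorting_analyzer.get_total_duration()
    return {unit_id: num_spikes[unit_id] / total_duration for unit_id in unit_ids}

class ToyRate(BaseMetric):
    metric_name = "toy_rate"
    metric_params = {}
    metric_columns = {"toy_rate": float}
    metric_descriptions = {"toy_rate": "Spikes per second (illustrative only)."}
    supports_periods = True  # compute() appends `periods` to the positional args
    metric_function = _toy_rate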
@@ -1079,6 +1097,7 @@ def _set_params( metrics_to_compute=metrics_to_compute, delete_existing_metrics=delete_existing_metrics, metric_params=metric_params, + periods=periods, **other_params, ) return params @@ -1129,6 +1148,8 @@ def _compute_metrics( if metric_names is None: metric_names = self.params["metric_names"] + periods = self.params.get("periods", None) + column_names_dtypes = {} for metric_name in metric_names: metric = [m for m in self.metric_list if m.metric_name == metric_name][0] @@ -1153,6 +1174,7 @@ def _compute_metrics( metric_params=metric_params, tmp_data=tmp_data, job_kwargs=job_kwargs, + periods=periods, ) except Exception as e: warnings.warn(f"Error computing metric {metric_name}: {e}") @@ -1179,6 +1201,7 @@ def _run(self, **job_kwargs): metrics_to_compute = self.params["metrics_to_compute"] delete_existing_metrics = self.params["delete_existing_metrics"] + periods = self.params.get("periods", None) _, job_kwargs = split_job_kwargs(job_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) @@ -1452,6 +1475,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods, ) all_data = all_data[keep_mask] + # since we have the mask already, we can use it directly to avoid double computation + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=True) + sliced_spike_vector = spike_vector[keep_mask] + sorting = NumpySorting( + sliced_spike_vector, + sampling_frequency=self.sorting_analyzer.sampling_frequency, + unit_ids=self.sorting_analyzer.unit_ids, + ) + else: + sorting = self.sorting_analyzer.sorting if outputs == "numpy": if copy: @@ -1460,10 +1493,10 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, return all_data elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + if keep_mask is not None: # since we are filtering spikes, we need to recompute the spike indices - spike_vector = spike_vector[keep_mask] + spike_vector = sorting.to_spike_vector(concatenated=False) spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) else: # use the cache of indices diff --git a/src/spikeinterface/metrics/conftest.py b/src/spikeinterface/metrics/conftest.py index 8d32c103fa..5313e763c1 100644 --- a/src/spikeinterface/metrics/conftest.py +++ b/src/spikeinterface/metrics/conftest.py @@ -1,8 +1,85 @@ import pytest -from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +def make_small_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[10.0], + num_units=10, + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + 
"principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer @pytest.fixture(scope="module") def small_sorting_analyzer(): - return _small_sorting_analyzer() + return make_small_analyzer() + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) + + return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index f985ea6e23..198e98037c 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -18,14 +18,18 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs -from spikeinterface.core import SortingAnalyzer, get_noise_levels +from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting from spikeinterface.core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) - -from ..spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.utils import ( + compute_bin_edges_per_unit, + compute_total_durations_per_unit, + compute_total_samples_per_unit, +) numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -34,7 +38,9 @@ HAVE_NUMBA = False -def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): +def compute_presence_ratios( + sorting_analyzer, unit_ids=None, periods=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0 +): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -44,6 +50,9 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the presence ratio. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. bin_duration_s : float, default: 60 The duration of each bin in seconds. 
If the duration is less than this value, presence_ratio is set to NaN. @@ -62,16 +71,23 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) - total_length = sorting_analyzer.get_total_samples() - total_duration = sorting_analyzer.get_total_duration() + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_samples = np.sum(segment_samples) bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - num_bin_edges = total_length // bin_duration_samples + 1 - bin_edges = np.arange(num_bin_edges) * bin_duration_samples - seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_duration_s, + ) + mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: raise ValueError( @@ -81,7 +97,7 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_length < bin_duration_samples: + if total_samples < bin_duration_samples: warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording duration. " f"Presence ratios are set to NaN." ) @@ -89,10 +105,20 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 else: for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + presence_ratios[unit_id] = np.nan + continue + spike_train = [] + bin_edges = bin_edges_per_unit[unit_id] + if len(bin_edges) < 2: + presence_ratios[unit_id] = 0.0 + continue + total_duration = total_durations[unit_id] + spike_train = [] for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(seg_lengths[:segment_index]) + st = st + np.sum(segment_samples[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -101,7 +127,6 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 presence_ratios[unit_id] = presence_ratio( spike_train, - total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -115,6 +140,7 @@ class PresenceRatio(BaseMetric): metric_params = {"bin_duration_s": 60, "mean_fr_ratio_thresh": 0.0} metric_columns = {"presence_ratio": float} metric_descriptions = {"presence_ratio": "Fraction of time the unit is active."} + supports_periods = True def compute_snrs( @@ -181,7 +207,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): +def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -196,6 +222,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 The SortingAnalyzer object. 
unit_ids : list or None List of unit ids to compute the ISI violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. This is the biophysical refractory period. @@ -234,10 +263,12 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - total_duration_s = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -247,16 +278,19 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 isi_violations_ratio = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + isi_violations_ratio[unit_id] = np.nan + isi_violations_count[unit_id] = -1 + continue + spike_train_list = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) if len(spike_train) > 0: spike_train_list.append(spike_train / fs) - if not any([len(train) > 0 for train in spike_train_list]): - continue - - ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) + total_duration = total_durations[unit_id] + ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -273,10 +307,11 @@ class ISIViolation(BaseMetric): "isi_violations_ratio": "Ratio of ISI violations for each unit.", "isi_violations_count": "Count of ISI violations for each unit.", } + supports_periods = True def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 + sorting_analyzer, unit_ids=None, periods=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 ): """ Calculate the number of refractory period violations. @@ -291,6 +326,9 @@ def compute_refrac_period_violations( The SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the refractory period violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. refractory_period_ms : float, default: 1.0 The period (in ms) where no 2 good spikes can occur. 
censored_period_ms : float, default: 0.0 @@ -318,44 +356,48 @@ def compute_refrac_period_violations( ---------- Based on metrics described in [Llobet]_ """ res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids + unit_ids = sorting.unit_ids if not HAVE_NUMBA: warnings.warn("Error: numba is not installed.") warnings.warn("compute_refrac_period_violations cannot run without numba.") - return {unit_id: np.nan for unit_id in unit_ids} + return res({unit_id: np.nan for unit_id in unit_ids}, {unit_id: 0 for unit_id in unit_ids}) - sorting = sorting_analyzer.sorting - fs = sorting_analyzer.sampling_frequency num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + fs = sorting_analyzer.sampling_frequency t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) - nb_rp_violations = {} - T = sorting_analyzer.get_total_samples() + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) nb_violations = {} rp_contamination = {} for unit_id in unit_ids: - nb_rp_violations[unit_id] = 0 + if num_spikes[unit_id] == 0: + rp_contamination[unit_id] = np.nan + nb_violations[unit_id] = -1 + continue + + nb_violations[unit_id] = 0 + total_samples_unit = total_samples[unit_id] + for segment_index in range(sorting_analyzer.get_num_segments()): spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - nb_rp_violations[unit_id] += _compute_rp_violations_numba(spike_times, t_c, t_r) - - n_v = nb_rp_violations[unit_id] - nb_violations[unit_id] = n_v - N = num_spikes[unit_id] - if N == 0: - rp_contamination[unit_id] = np.nan - else: - D = 1 - n_v * (T - 2 * N * t_c) / (N**2 * (t_r - t_c)) - rp_contamination[unit_id] = 1 - math.sqrt(D) if D >= 0 else 1.0 + nb_violations[unit_id] += _compute_rp_violations_numba(spike_times, t_c, t_r) + + rp_contamination[unit_id] = _compute_rp_contamination_one_unit( + nb_violations[unit_id], + num_spikes[unit_id], + total_samples_unit, + t_c, + t_r, + ) return res(rp_contamination, nb_violations) @@ -369,11 +411,13 @@ class RPViolation(BaseMetric): "rp_contamination": "Refractory period contamination described in Llobet & Wyngaard 2022.", "rp_violations": "Number of refractory period violations.", } + supports_periods = True def compute_sliding_rp_violations( sorting_analyzer, unit_ids=None, + periods=None, min_spikes=0, bin_size_ms=0.25, window_size_s=1, @@ -392,6 +436,9 @@ def compute_sliding_rp_violations( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the sliding RP violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. min_spikes : int, default: 0 Contamination is set to np.nan if the unit has fewer than this many spikes across all segments.
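As a usage sketch for the new `periods` argument (not part of the patch): the dtype import mirrors the updated tests below, the field order is taken from the docstrings above, and a `sorting_analyzer` with at least two units and two minutes of data is assumed.

import numpy as np
from spikeinterface.core.base import unit_period_dtype

fs = sorting_analyzer.sampling_frequency
periods = np.zeros(2, dtype=unit_period_dtype)
# fields: (segment_index, start_sample_index, end_sample_index, unit_index)
periods[0] = (0, 0, int(60 * fs), 0)              # unit index 0: first 60 s of segment 0
periods[1] = (0, int(60 * fs), int(120 * fs), 1)  # unit index 1: the next 60 s
rp_contamination, rp_violations = compute_refrac_period_violations(sorting_analyzer, periods=periods)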
@@ -417,8 +464,10 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + if unit_ids is None: unit_ids = sorting_analyzer.unit_ids @@ -442,7 +491,7 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue - from spikeinterface.core.numpyextractors import NumpySorting + duration = total_durations[unit_id] sub_sorting = NumpySorting(sub_spikes, fs, unit_ids=[unit_id]) @@ -474,9 +523,10 @@ class SlidingRPViolation(BaseMetric): metric_descriptions = { "sliding_rp_violation": "Minimum contamination at 90% confidence using sliding refractory period method." } + supports_periods = True -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -487,6 +537,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N A SortingAnalyzer object. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. synchrony_sizes: None, default: None Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. 
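The period-aware functions in this file converge on one skeleton: restrict the sorting with `select_periods`, fetch per-unit totals, and fall back to NaN for units left without spikes. A distilled sketch follows; the metric body itself (a plain rate) is invented.

import numpy as np

def compute_some_rate_metric(sorting_analyzer, unit_ids=None, periods=None):
    sorting = sorting_analyzer.sorting
    sorting = sorting.select_periods(periods=periods)  # drop spikes outside each unit's periods
    if unit_ids is None:
        unit_ids = sorting_analyzer.unit_ids
    total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods)
    num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
    results = {}
    for unit_id in unit_ids:
        if num_spikes[unit_id] == 0:
            results[unit_id] = np.nan  # empty units get NaN, as in the metrics above
            continue
        results[unit_id] = num_spikes[unit_id] / total_durations[unit_id]  # e.g. a firing rate
    return results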
@@ -510,11 +563,12 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids - spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids @@ -527,10 +581,10 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N for i, unit_id in enumerate(all_unit_ids): if unit_id not in unit_ids: continue - if spike_counts[unit_id] != 0: - sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id] + if num_spikes[unit_id] != 0: + sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / num_spikes[unit_id] else: - sync_id_metrics_dict[unit_id] = 0 + sync_id_metrics_dict[unit_id] = -1 synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict return res(**synchrony_metrics_dict) @@ -545,9 +599,10 @@ class Synchrony(BaseMetric): "sync_spike_4": "Fraction of spikes that are synchronous with at least three other spikes.", "sync_spike_8": "Fraction of spikes that are synchronous with at least seven other spikes.", } + supports_periods = True -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_size_s=5, percentiles=(5, 95)): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -558,6 +613,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the firing range. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. bin_size_s : float, default: 5 The size of the bin in seconds. percentiles : tuple, default: (5, 95) @@ -575,37 +633,59 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + segment_samples = [ + sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) + ] + if unit_ids is None: unit_ids = sorting.unit_ids - if all( - [ - sorting_analyzer.get_num_samples(segment_index) < bin_size_samples - for segment_index in range(sorting_analyzer.get_num_segments()) - ] - ): - warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. 
Firing ranges are set to NaN.") - return {unit_id: np.nan for unit_id in unit_ids} + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in unit_ids} + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_size_s, + ) + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: - firing_histograms = [] + if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: + continue + bin_edges = bin_edges_per_unit[unit_id] + + # we can concatenate spike trains across segments adding the cumulative number of samples + # as offset, since bin edges are already cumulative + spike_trains = [] for segment_index in range(sorting_analyzer.get_num_segments()): - num_samples = sorting_analyzer.get_num_samples(segment_index) - edges = np.arange(0, num_samples + 1, bin_size_samples) - spike_times = sorting.get_unit_spike_train(unit_id, segment_index) - spike_counts, _ = np.histogram(spike_times, bins=edges) - firing_rates = spike_counts / bin_size_s - firing_histograms += [firing_rates] - firing_rate_histograms[unit_id] = np.concatenate(firing_histograms) + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_times = spike_times + cumulative_segment_samples[segment_index] + spike_trains.append(spike_times) + spike_train = np.concatenate(spike_trains, dtype="int64") + + spike_counts, _ = np.histogram(spike_train, bins=bin_edges) + firing_rate_histograms[unit_id] = spike_counts / bin_size_s # finally we compute the percentiles firing_ranges = {} + failed_units = [] for unit_id in unit_ids: + if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: + failed_units.append(unit_id) + firing_ranges[unit_id] = np.nan + continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( firing_rate_histograms[unit_id], percentiles[0] ) + if len(failed_units) > 0: + warnings.warn( + f"Firing range could not be computed for units {failed_units} " + f"because they have no spikes or the total duration is less than bin size." + ) return firing_ranges @@ -618,11 +698,13 @@ class FiringRange(BaseMetric): metric_descriptions = { "firing_range": "Range between the percentiles (default: 5th and 95th) of the firing rates distribution." } + supports_periods = True def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, + periods=None, average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, @@ -639,6 +721,9 @@ def compute_amplitude_cv_metrics( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude spread. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. average_num_spikes_per_bin : int, default: 50 The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate of each unit. 
For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is @@ -668,17 +753,24 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = sorting_analyzer.sorting - total_duration = sorting_analyzer.get_total_duration() - if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_analyzer.unit_ids + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) - num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids, outputs="dict") - amps = sorting_analyzer.get_extension(amplitude_extension).get_data(outputs="by_unit", concatenated=False) + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + num_spikes = sorting.count_num_spikes_per_unit(outputs="dict", unit_ids=unit_ids) + amps = sorting_analyzer.get_extension(amplitude_extension).get_data( + outputs="by_unit", concatenated=False, periods=periods + ) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + continue + total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency @@ -723,12 +815,14 @@ class AmplitudeCV(BaseMetric): "amplitude_cv_median": "Median of the coefficient of variation of spike amplitudes within temporal bins.", "amplitude_cv_range": "Range of the coefficient of variation of spike amplitudes within temporal bins.", } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_amplitude_cutoffs( sorting_analyzer, unit_ids=None, + periods=None, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, @@ -742,6 +836,9 @@ def compute_amplitude_cutoffs( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude cutoffs. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. num_histogram_bins : int, default: 100 The number of bins to use to compute the amplitude histogram. histogram_smoothing_value : int, default: 3 @@ -785,14 +882,13 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(concatenated=True, outputs="by_unit") + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] if invert_amplitudes: amplitudes = -amplitudes - all_fraction_missing[unit_id] = amplitude_cutoff( amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio ) @@ -815,10 +911,11 @@ class AmplitudeCutoff(BaseMetric): metric_descriptions = { "amplitude_cutoff": "Estimated fraction of missing spikes, based on the amplitude distribution." 
} + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): """ Compute median of the amplitude distributions (in absolute value). @@ -828,6 +925,9 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -846,8 +946,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) - + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -861,10 +960,13 @@ class AmplitudeMedian(BaseMetric): metric_descriptions = { "amplitude_median": "Median of the amplitude distributions (in absolute value) for each unit in uV." } + supports_periods = True depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): +def compute_noise_cutoffs( + sorting_analyzer, unit_ids=None, periods=None, high_quantile=0.25, low_quantile=0.1, n_bins=100 +): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -882,6 +984,9 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude cutoffs. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. high_quantile : float, default: 0.25 Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. low_quantile : int, default: 0.1 @@ -916,10 +1021,14 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] + if len(amplitudes) == 0: + cutoff, ratio = np.nan, np.nan + continue + if invert_amplitudes: amplitudes = -amplitudes @@ -942,12 +1051,14 @@ class NoiseCutoff(BaseMetric): ), "noise_ratio": "Ratio of counts in the lower-amplitude bins to the count in the highest bin.", } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_drift_metrics( sorting_analyzer, unit_ids=None, + periods=None, interval_s=60, min_spikes_per_interval=100, direction="y", @@ -975,6 +1086,9 @@ def compute_drift_metrics( A SortingAnalyzer object. 
unit_ids : list or None, default: None List of unit ids to compute the drift metrics. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. interval_s : int, default: 60 Interval length in seconds for computing spike depth. min_spikes_per_interval : int, default: 100 @@ -1010,29 +1124,25 @@ def compute_drift_metrics( check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations_by_unit_and_segments = spike_locations_ext.get_data(outputs="by_unit", concatenated=False) - spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True) + spike_locations_by_unit_and_segments = spike_locations_ext.get_data( + outputs="by_unit", concatenated=False, periods=periods + ) + spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) - interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] data = spike_locations_by_unit[unit_ids[0]] assert direction in data.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{data.dtype.names}" ) - total_duration = sorting_analyzer.get_total_duration() - if total_duration < min_num_bins * interval_s: - warnings.warn( - "The recording is too short given the specified 'interval_s' and "
Drift metrics will be set to NaN" - ) - empty_dict = {unit_id: np.nan for unit_id in unit_ids} - if return_positions: - return res(empty_dict, empty_dict, empty_dict), np.nan - else: - return res(empty_dict, empty_dict, empty_dict) + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, segment_samples=segment_samples, periods=periods, bin_duration_s=interval_s, concatenated=False + ) + failed_units = [] # we need drift_ptps = {} @@ -1040,40 +1150,44 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = np.zeros(len(unit_ids)) - median_position_segments = None + reference_positions = {} + median_position_segments = {unit_id: np.array([]) for unit_id in unit_ids} - for i, unit_id in enumerate(unit_ids): - reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) + for unit_id in unit_ids: + reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) for segment_index in range(sorting_analyzer.get_num_segments()): - seg_length = sorting_analyzer.get_num_samples(segment_index) - num_bin_edges = seg_length // interval_samples + 1 - bins = np.arange(num_bin_edges) * interval_samples - median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) - for i, unit_id in enumerate(unit_ids): + for unit_id in unit_ids: + bins = bin_edges_for_units[unit_id][segment_index] + num_bin_edges = len(bins) + if (num_bin_edges - 1) < min_num_bins: + failed_units.append(unit_id) + continue + median_positions = np.nan * np.zeros((num_bin_edges - 1)) spikes_in_segment_of_unit = sorting.get_unit_spike_train(unit_id, segment_index) bounds = np.searchsorted(spikes_in_segment_of_unit, bins, side="left") for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])): spike_locations_in_bin = spike_locations_by_unit_and_segments[segment_index][unit_id][i0:i1][direction] if (i1 - i0) >= min_spikes_per_interval: - median_positions[i, bin_index] = np.median(spike_locations_in_bin) - - if median_position_segments is None: - median_position_segments = median_positions - else: - median_position_segments = np.hstack((median_position_segments, median_positions)) + median_positions[bin_index] = np.median(spike_locations_in_bin) + median_position_segments[unit_id] = np.concatenate((median_position_segments[unit_id], median_positions)) # finally, compute deviations and drifts - position_diffs = median_position_segments - reference_positions[:, None] - for i, unit_id in enumerate(unit_ids): - position_diff = position_diffs[i] + for unit_id in unit_ids: + # Skip units that already failed because not enough bins in at least one segment + if unit_id in failed_units: + drift_ptps[unit_id] = np.nan + drift_stds[unit_id] = np.nan + drift_mads[unit_id] = np.nan + continue + position_diff = median_position_segments[unit_id] - reference_positions[unit_id] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): ptp_drift = np.nan std_drift = np.nan mad_drift = np.nan + failed_units.append(unit_id) else: ptp_drift = np.nanmax(position_diff) - np.nanmin(position_diff) std_drift = np.nanstd(np.abs(position_diff)) @@ -1085,6 +1199,13 @@ def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + + if len(failed_units) > 0: + warnings.warn( + f"Drift metrics could not be computed for units {failed_units} because they have 
less than " + f"{min_num_bins} bins given the specified 'interval_s' and 'min_num_bins' or not enough valid intervals." + ) + if return_positions: outs = res(drift_ptps, drift_stds, drift_mads), median_positions else: @@ -1107,12 +1228,14 @@ class Drift(BaseMetric): "drift_std": "Standard deviation of the drift signal in um.", "drift_mad": "Median absolute deviation of the drift signal in um.", } + supports_periods = True depend_on = ["spike_locations"] def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, + periods=None, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, @@ -1131,6 +1254,9 @@ def compute_sd_ratio( A SortingAnalyzer object. unit_ids : list or None, default: None The list of unit ids to compute this metric. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. censored_period_ms : float, default: 4.0 The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. correct_for_drift : bool, default: True @@ -1154,11 +1280,14 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + if not sorting_analyzer.has_recording(): warnings.warn( "The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object" @@ -1167,7 +1296,7 @@ def compute_sd_ratio( return {unit_id: np.nan for unit_id in unit_ids} spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data( - outputs="by_unit", concatenated=False + outputs="by_unit", concatenated=False, periods=periods ) if not HAVE_NUMBA: @@ -1189,6 +1318,9 @@ def compute_sd_ratio( sd_ratio = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + sd_ratio[unit_id] = np.nan + continue spk_amp = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id, segment_index) @@ -1250,6 +1382,7 @@ class SDRatio(BaseMetric): "sd_ratio": "Ratio between the standard deviation of spike amplitudes and the standard deviation of noise." } needs_recording = True + supports_periods = True depend_on = ["templates", "spike_amplitudes"] @@ -1294,7 +1427,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1302,8 +1435,6 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None ---------- spike_train : np.ndarray Spike times for this unit, in samples. - total_length : int - Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). 
num_bin_edges : int, optional @@ -1531,6 +1662,46 @@ def slidingRP_violations( return min_cont_with_90_confidence +def _compute_rp_contamination_one_unit( + n_v, + n_spikes, + total_samples, + t_c, + t_r, +): + """ + Compute the refractory period contamination for one unit. + + Parameters + ---------- + n_v : int + Number of refractory period violations. + n_spikes : int + Number of spikes for the unit. + total_samples : int + Total number of samples in the recording. + t_c : int + Censored period in samples. + t_r : int + Refractory period in samples. + + Returns + ------- + rp_contamination : float + The refractory period contamination for the unit. + """ + if n_spikes <= 1: + return np.nan + + denom = 1 - n_v * (total_samples - 2 * n_spikes * t_c) / (n_spikes**2 * (t_r - t_c)) + if denom < 0: + return 1.0 + + rp_contamination = 1 - math.sqrt(denom) + + return rp_contamination + + def _compute_violations(obs_viol, firing_rate, spike_count, ref_period_dur, contamination_prop): contamination_rate = firing_rate * contamination_prop expected_viol = contamination_rate * ref_period_dur * 2 * spike_count @@ -1678,29 +1849,6 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): - # used by compute_amplitude_cutoffs and compute_amplitude_medians - - if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: - return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) - - elif sorting_analyzer.has_extension("waveforms"): - amplitudes_by_units = {} - waveforms_ext = sorting_analyzer.get_extension("waveforms") - before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) - for unit_id in unit_ids: - waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - chan_id = extremum_channels_ids[unit_id] - if sorting_analyzer.is_sparse(): - chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] - amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - - return amplitudes_by_units - - if HAVE_NUMBA: import numba diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 8e96f4dcaf..e5cc2aa323 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -70,6 +70,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods=None, # common extension kwargs peak_sign=None, seed=None, @@ -90,6 +91,7 @@ def _set_params( metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, + periods=periods, peak_sign=peak_sign, seed=seed, skip_pc_metrics=skip_pc_metrics, diff --git a/src/spikeinterface/metrics/quality/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py deleted file mode 100644 index c2a6c6fe82..0000000000 --- a/src/spikeinterface/metrics/quality/tests/conftest.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest - -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, -) - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def make_small_analyzer(): - recording, sorting = 
generate_ground_truth_recording( - durations=[2.0], - num_units=10, - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) - - sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {"operators": ["average", "median"]}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - sorting_analyzer.compute(extensions_to_compute) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def small_sorting_analyzer(): - return make_small_analyzer() - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) - - return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index a886fbb2e7..14b7b40e16 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from copy import deepcopy import csv @@ -12,24 +11,14 @@ synthesize_random_firings, ) -from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, create_regular_periods -# from spikeinterface.metrics.quality_metric_list import ( -# _misc_metric_name_to_func, -# ) - -from spikeinterface.metrics.quality import ( - get_quality_metric_list, - get_quality_pca_metric_list, - compute_quality_metrics, -) +from spikeinterface.metrics.quality import get_quality_metric_list, compute_quality_metrics, ComputeQualityMetrics from spikeinterface.metrics.quality.misc_metrics import ( misc_metrics_list, compute_amplitude_cutoffs, compute_presence_ratios, compute_isi_violations, - # compute_firing_rates, - # 
compute_num_spikes, compute_snrs, compute_refrac_period_violations, compute_sliding_rp_violations, @@ -44,7 +33,6 @@ ) from spikeinterface.metrics.quality.pca_metrics import ( - pca_metrics_list, mahalanobis_metrics, d_prime_metric, nearest_neighbors_metrics, @@ -53,257 +41,10 @@ ) -from spikeinterface.core.base import minimum_spike_dtype - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def test_noise_cutoff(): - """ - Generate two artifical gaussian, one truncated and one not. Check the metrics are higher for the truncated one. - """ - np.random.seed(1) - amps = np.random.normal(0, 1, 1000) - amps_trunc = amps[amps > -1] - - cutoff1, ratio1 = _noise_cutoff(amps=amps) - cutoff2, ratio2 = _noise_cutoff(amps=amps_trunc) - - assert cutoff1 <= cutoff2 - assert ratio1 <= ratio2 - - -def test_compute_new_quality_metrics(small_sorting_analyzer): - """ - Computes quality metrics then computes a subset of quality metrics, and checks - that the old quality metrics are not deleted. - """ - - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "firing_range": {"bin_size_s": 1}, - } - - small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - qm_extension = small_sorting_analyzer.get_extension("quality_metrics") - calculated_metrics = list(qm_extension.get_data().keys()) - - assert calculated_metrics == ["snr"] - - small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} - ) - small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - - quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - - # Check old metrics are not deleted and the new one is added to the data and metadata - assert set(list(quality_metric_extension.get_data().keys())) == set( - [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - ) - assert set(list(quality_metric_extension.params.get("metric_names"))) == set( - [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - ) - - # check that, when parameters are changed, the data and metadata are updated - old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) - small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} - ) - new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - new_snr_data = new_quality_metric_extension.get_data()["snr"].values - - assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" - - -def test_metric_names_in_same_order(small_sorting_analyzer): - """ - Computes sepecified quality metrics and checks order is propagated. - """ - specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] - small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) - qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() - for i in range(3): - assert specified_metric_names[i] == qm_keys[i] - - -def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): - """ - Computes quality metrics in binary folder format. Then computes subsets of quality - metrics and checks if they are saved correctly. - """ - - # can't use _misc_metric_name_to_func as some functions compute several qms - # e.g. 
isi_violation and synchrony - quality_metrics = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violations_ratio", - "isi_violations_count", - "rp_contamination", - "rp_violations", - "sliding_rp_violation", - "amplitude_cutoff", - "amplitude_median", - "amplitude_cv_median", - "amplitude_cv_range", - "sync_spike_2", - "sync_spike_4", - "sync_spike_8", - "firing_range", - "drift_ptp", - "drift_std", - "drift_mad", - "sd_ratio", - "isolation_distance", - "l_ratio", - "d_prime", - "silhouette", - "nn_hit_rate", - "nn_miss_rate", - ] - - small_sorting_analyzer.compute("quality_metrics") - - cache_folder = create_cache_folder - output_folder = cache_folder / "sorting_analyzer" - - folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) - quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - assert metric_name in metric_names - - folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - assert metric_name in metric_names - - folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - if metric_name == "snr": - assert metric_name in metric_names - else: - assert metric_name not in metric_names - - -def test_unit_structure_in_output(small_sorting_analyzer): - - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, - "firing_range": {"bin_size_s": 1}, - "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, - "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, - "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, - } - - for metric in misc_metrics_list: - metric_name = metric.metric_name - metric_fun = metric.metric_function - try: - qm_param = qm_params[metric_name] - except: - qm_param = {} - - result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) - result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) - - error = "Problem with metric: " + metric_name - - if isinstance(result_all, dict): - assert list(result_all.keys()) == ["#3", "#9", "#4"], error - assert list(result_sub.keys()) == ["#4", "#9"], error - assert result_sub["#9"] == result_all["#9"], error - assert result_sub["#4"] == result_all["#4"], error - - else: - for result_ind, result in enumerate(result_sub): - - assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error - assert result_sub[result_ind].keys() == set(["#4", "#9"]), error - - assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error - assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error - - -def test_unit_id_order_independence(small_sorting_analyzer): - """ - Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders 
and have different labels,
-    and checks that their calculated quality metrics are independent of the ordering and labelling.
-    """
-
-    recording = small_sorting_analyzer.recording
-    sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2])
-
-    small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
-
-    extensions_to_compute = {
-        "random_spikes": {"seed": 1205},
-        "noise_levels": {"seed": 1205},
-        "waveforms": {},
-        "templates": {},
-        "spike_amplitudes": {},
-        "spike_locations": {},
-        "principal_components": {},
-    }
-
-    small_sorting_analyzer_2.compute(extensions_to_compute)
-
-    # need special params to get non-nan results on a short recording
-    qm_params = {
-        "presence_ratio": {"bin_duration_s": 0.1},
-        "amplitude_cutoff": {"num_histogram_bins": 3},
-        "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
-        "firing_range": {"bin_size_s": 1},
-        "isi_violation": {"isi_threshold_ms": 10},
-        "drift": {"interval_s": 1, "min_spikes_per_interval": 5},
-        "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
-    }
-
-    quality_metrics_1 = compute_quality_metrics(
-        small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True
-    )
-    quality_metrics_2 = compute_quality_metrics(
-        small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True
-    )
-
-    for metric, metric_2_data in quality_metrics_2.items():
-        error = "Problem with the metric " + metric
-        assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error
-        assert quality_metrics_1[metric]["#9"] == metric_2_data[7], error
-        assert quality_metrics_1[metric]["#4"] == metric_2_data[1], error
+from spikeinterface.core.base import minimum_spike_dtype, unit_period_dtype
 
 
+### HELPER FUNCTIONS AND FIXTURES ###
 def _sorting_violation():
     max_time = 100.0
     sampling_frequency = 30000
@@ -334,7 +75,6 @@ def _sorting_violation():
 
 
 def _sorting_analyzer_violations():
-
     sorting = _sorting_violation()
     duration = (sorting.to_spike_vector()["sample_index"][-1] + 1) / sorting.sampling_frequency
@@ -356,6 +96,40 @@ def sorting_analyzer_violations():
     return _sorting_analyzer_violations()
 
 
+@pytest.fixture
+def periods_simple(sorting_analyzer_simple):
+    sorting_analyzer = sorting_analyzer_simple
+    periods = create_regular_periods(sorting_analyzer, num_periods=5)
+    return periods
+
+
+@pytest.fixture
+def periods_violations(sorting_analyzer_violations):
+    sorting_analyzer = sorting_analyzer_violations
+    periods = create_regular_periods(sorting_analyzer, num_periods=5)
+    return periods
+
+
+# Common job kwargs
+job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
+
+
+### LOW-LEVEL TESTS ###
+def test_noise_cutoff():
+    """
+    Generate two artificial Gaussian distributions, one truncated and one not. Check that the metrics are higher for the truncated one.
+ """ + np.random.seed(1) + amps = np.random.normal(0, 1, 1000) + amps_trunc = amps[amps > -1] + + cutoff1, ratio1 = _noise_cutoff(amps=amps) + cutoff2, ratio2 = _noise_cutoff(amps=amps_trunc) + + assert cutoff1 <= cutoff2 + assert ratio1 <= ratio2 + + def test_synchrony_counts_no_sync(): spike_times, spike_units = synthesize_random_firings(num_units=1, duration=1, firing_rates=1.0) @@ -488,22 +262,17 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -# def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): -# sorting_analyzer = sorting_analyzer_simple -# firing_rates = compute_firing_rates(sorting_analyzer) -# num_spikes = compute_num_spikes(sorting_analyzer) - -# testing method accuracy with magic number is not a good pratcice, I remove this. -# firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} -# num_spikes_gt = {0: 1001, 1: 503, 2: 509} -# assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05) -# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) - - +### TEST METRICS FUNCTIONS ### def test_calculate_firing_range(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - firing_ranges = compute_firing_ranges(sorting_analyzer) - print(firing_ranges) + firing_ranges = compute_firing_ranges(sorting_analyzer, bin_size_s=1) + periods = create_regular_periods(sorting_analyzer, num_periods=5, bin_size_s=1) + firing_ranges_periods = compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1) + assert firing_ranges == firing_ranges_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + firing_ranges_empty = compute_firing_ranges(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(firing_ranges_empty.values())))) with pytest.warns(UserWarning) as w: firing_ranges_nan = compute_firing_ranges( @@ -516,6 +285,13 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_cuts = compute_amplitude_cutoffs(sorting_analyzer, num_histogram_bins=10) + periods = create_regular_periods(sorting_analyzer, num_periods=5) + amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10) + assert amp_cuts == amp_cuts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_cuts_empty = compute_amplitude_cutoffs(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(amp_cuts_empty.values())))) # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -527,18 +303,39 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_medians = compute_amplitude_medians(sorting_analyzer) - # print(amp_medians) + periods = create_regular_periods(sorting_analyzer, num_periods=5) + amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods) + assert amp_medians == amp_medians_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_medians_empty = compute_amplitude_medians(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(amp_medians_empty.values())))) # testing method accuracy with magic number is not a good pratcice, I remove this. 
# amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): +def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_analyzer, average_num_spikes_per_bin=20) - print(amp_cv_median) - print(amp_cv_range) + periods = periods_simple + amp_cv_median_periods, amp_cv_range_periods = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=periods, + average_num_spikes_per_bin=20, + ) + assert amp_cv_median == amp_cv_median_periods + assert amp_cv_range == amp_cv_range_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_cv_median_empty, amp_cv_range_empty = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=empty_periods, + average_num_spikes_per_bin=20, + ) + assert np.all(np.isnan(np.array(list(amp_cv_median_empty.values())))) + assert np.all(np.isnan(np.array(list(amp_cv_range_empty.values())))) # amps_scalings = compute_amplitude_scalings(sorting_analyzer) sorting_analyzer.compute("amplitude_scalings", **job_kwargs) @@ -548,34 +345,56 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): amplitude_extension="amplitude_scalings", min_num_bins=5, ) - print(amp_cv_median_scalings) - print(amp_cv_range_scalings) + amp_cv_median_scalings_periods, amp_cv_range_scalings_periods = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=periods, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + assert amp_cv_median_scalings == amp_cv_median_scalings_periods + assert amp_cv_range_scalings == amp_cv_range_scalings_periods -def test_calculate_snrs(sorting_analyzer_simple): +def test_calculate_snrs(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple snrs = compute_snrs(sorting_analyzer) - print(snrs) + # SNR doesn't support periods # testing method accuracy with magic number is not a good pratcice, I remove this. # snrs_gt = {0: 12.92, 1: 12.99, 2: 12.99} # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(sorting_analyzer_simple): +def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple ratios = compute_presence_ratios(sorting_analyzer, bin_duration_s=10) - print(ratios) + periods = periods_simple + ratios_periods = compute_presence_ratios(sorting_analyzer, periods=periods, bin_duration_s=10) + assert ratios == ratios_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + ratios_periods_empty = compute_presence_ratios(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(ratios_periods_empty.values())))) # testing method accuracy with magic number is not a good pratcice, I remove this. 
# ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0} # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(sorting_analyzer_violations): +def test_calculate_isi_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations isi_viol, counts = compute_isi_violations(sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0) - print(isi_viol) + periods = periods_violations + isi_viol_periods, counts_periods = compute_isi_violations( + sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0, periods=periods + ) + assert isi_viol == isi_viol_periods + assert counts == counts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + isi_viol_empty, isi_counts_empty = compute_isi_violations(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(isi_viol_empty.values())))) + assert np.array_equal(np.array(list(isi_counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))) # testing method accuracy with magic number is not a good pratcice, I remove this. # isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754} @@ -584,22 +403,44 @@ def test_calculate_isi_violations(sorting_analyzer_violations): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(sorting_analyzer_violations): +def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations contaminations = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) - print(contaminations) + periods = periods_violations + contaminations_periods = compute_sliding_rp_violations( + sorting_analyzer, periods=periods, bin_size_ms=0.25, window_size_s=1 + ) + assert contaminations == contaminations_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + contaminations_periods_empty = compute_sliding_rp_violations( + sorting_analyzer, periods=empty_periods, bin_size_ms=0.25, window_size_s=1 + ) + assert np.all(np.isnan(np.array(list(contaminations_periods_empty.values())))) # testing method accuracy with magic number is not a good pratcice, I remove this. # contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325} # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def test_calculate_rp_violations(sorting_analyzer_violations): +def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations rp_contamination, counts = compute_refrac_period_violations( sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0 ) - print(rp_contamination, counts) + periods = periods_violations + rp_contamination_periods, counts_periods = compute_refrac_period_violations( + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=periods + ) + assert rp_contamination == rp_contamination_periods + assert counts == counts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + rp_contamination_empty, counts_empty = compute_refrac_period_violations( + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=empty_periods + ) + assert np.all(np.isnan(np.array(list(rp_contamination_empty.values())))) + assert np.array_equal(np.array(list(counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))) # testing method accuracy with magic number is not a good pratcice, I remove this. 
# counts_gt = {0: 2, 1: 4, 2: 10}
@@ -619,13 +460,27 @@ def test_calculate_rp_violations(sorting_analyzer_violations):
     assert np.isnan(rp_contamination[1])
 
 
-def test_synchrony_metrics(sorting_analyzer_simple):
+def test_synchrony_metrics(sorting_analyzer_simple, periods_simple):
     sorting_analyzer = sorting_analyzer_simple
     sorting = sorting_analyzer.sorting
     synchrony_metrics = compute_synchrony_metrics(sorting_analyzer)
+    periods = periods_simple
+    synchrony_metrics_periods = compute_synchrony_metrics(sorting_analyzer, periods=periods)
+    assert synchrony_metrics == synchrony_metrics_periods
+
+    empty_periods = np.empty(0, dtype=unit_period_dtype)
+    synchrony_metrics_empty = compute_synchrony_metrics(sorting_analyzer, periods=empty_periods)
+    assert np.array_equal(
+        np.array(list(synchrony_metrics_empty.sync_spike_2.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))
+    )
+    assert np.array_equal(
+        np.array(list(synchrony_metrics_empty.sync_spike_4.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))
+    )
+    assert np.array_equal(
+        np.array(list(synchrony_metrics_empty.sync_spike_8.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))
+    )
     synchrony_sizes = np.array([2, 4, 8])
-    # check returns
     for size in synchrony_sizes:
         assert f"sync_spike_{size}" in synchrony_metrics._fields
 
@@ -678,6 +533,22 @@ def test_calculate_drift_metrics(sorting_analyzer_simple):
     drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics(
         sorting_analyzer, interval_s=10, min_spikes_per_interval=10
     )
+    periods = create_regular_periods(sorting_analyzer, num_periods=5, bin_size_s=10)
+    drifts_ptps_periods, drifts_stds_periods, drift_mads_periods = compute_drift_metrics(
+        sorting_analyzer, periods=periods, min_spikes_per_interval=10, interval_s=10
+    )
+    assert drifts_ptps == drifts_ptps_periods
+    assert drifts_stds == drifts_stds_periods
+    assert drift_mads == drift_mads_periods
+
+    # calculate drift metrics with empty periods
+    empty_periods = np.empty(0, dtype=unit_period_dtype)
+    drifts_ptps_empty, drifts_stds_empty, drift_mads_empty = compute_drift_metrics(
+        sorting_analyzer_simple, periods=empty_periods
+    )
+    assert np.all(np.isnan(np.array(list(drifts_ptps_empty.values()))))
+    assert np.all(np.isnan(np.array(list(drifts_stds_empty.values()))))
+    assert np.all(np.isnan(np.array(list(drift_mads_empty.values()))))
 
     # print(drifts_ptps, drifts_stds, drift_mads)
 
@@ -690,25 +561,234 @@ def test_calculate_drift_metrics(sorting_analyzer_simple):
     # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05)
 
 
-def test_calculate_sd_ratio(sorting_analyzer_simple):
+def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple):
     sd_ratio = compute_sd_ratio(
         sorting_analyzer_simple,
     )
+    periods = periods_simple
+    sd_ratio_periods = compute_sd_ratio(sorting_analyzer_simple, periods=periods)
+    assert sd_ratio == sd_ratio_periods
 
     assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids)
+
+    # calculate sd_ratio with empty periods
+    empty_periods = np.empty(0, dtype=unit_period_dtype)
+    sd_ratios_empty_periods = compute_sd_ratio(sorting_analyzer_simple, periods=empty_periods)
+    assert np.all(np.isnan(np.array(list(sd_ratios_empty_periods.values()))))
 
     # @aurelien can you check this, this is not working anymore
     # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0)
 
 
-if __name__ == "__main__":
+### MACHINERY TESTS ###
+def test_compute_new_quality_metrics(small_sorting_analyzer):
+    """
+    Computes quality metrics then computes a subset of quality metrics, and checks
+    that the old quality metrics are not deleted.
+    """
+
+    qm_params = {
+        "presence_ratio": {"bin_duration_s": 0.1},
+        "amplitude_cutoff": {"num_histogram_bins": 3},
+        "firing_range": {"bin_size_s": 1},
+    }
 
-    sorting_analyzer = _sorting_analyzer_simple()
-    print(sorting_analyzer)
+    small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}})
+    qm_extension = small_sorting_analyzer.get_extension("quality_metrics")
+    calculated_metrics = list(qm_extension.get_data().keys())
 
-    test_unit_structure_in_output(_small_sorting_analyzer())
+    assert calculated_metrics == ["snr"]
 
-    # test_calculate_firing_rate_num_spikes(sorting_analyzer)
+    small_sorting_analyzer.compute(
+        {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}}
+    )
+    small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}})
+
+    quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics")
+
+    # Check old metrics are not deleted and the new one is added to the data and metadata
+    assert set(list(quality_metric_extension.get_data().keys())) == set(
+        [
+            "amplitude_cutoff",
+            "firing_range",
+            "presence_ratio",
+            "snr",
+        ]
+    )
+    assert set(list(quality_metric_extension.params.get("metric_names"))) == set(
+        [
+            "amplitude_cutoff",
+            "firing_range",
+            "presence_ratio",
+            "snr",
+        ]
+    )
+
+    # check that, when parameters are changed, the data and metadata are updated
+    old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values)
+    small_sorting_analyzer.compute(
+        {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}}
+    )
+    new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics")
+    new_snr_data = new_quality_metric_extension.get_data()["snr"].values
+
+    assert np.all(old_snr_data != new_snr_data)
+    assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak"
+
+
+def test_metric_names_in_same_order(small_sorting_analyzer):
+    """
+    Computes specified quality metrics and checks order is propagated.
+    """
+    specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"]
+    small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names)
+    qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys()
+    for i in range(3):
+        assert specified_metric_names[i] == qm_keys[i]
+
+
+def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
+    """
+    Computes quality metrics in binary folder format. Then computes subsets of quality
+    metrics and checks if they are saved correctly.
+    """
+
+    # can't use _misc_metric_name_to_func as some functions compute several qms
+    # e.g.
isi_violation and synchrony + quality_metric_columns = ComputeQualityMetrics.get_metric_columns() + all_metrics = ComputeQualityMetrics.get_available_metric_names() + small_sorting_analyzer.compute("quality_metrics", metric_names=all_metrics) + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metric_columns: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metric_columns: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metric_columns: + if metric_name == "snr": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + +def test_unit_structure_in_output(small_sorting_analyzer): + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, + } + + for metric in misc_metrics_list: + metric_name = metric.metric_name + metric_fun = metric.metric_function + try: + qm_param = qm_params[metric_name] + except: + qm_param = {} + + result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) + result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) + + error = "Problem with metric: " + metric_name + + if isinstance(result_all, dict): + assert list(result_all.keys()) == ["#3", "#9", "#4"], error + assert list(result_sub.keys()) == ["#4", "#9"], error + assert result_sub["#9"] == result_all["#9"], error + assert result_sub["#4"] == result_all["#4"], error + + else: + for result_ind, result in enumerate(result_sub): + + assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error + assert result_sub[result_ind].keys() == set(["#4", "#9"]), error + + assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error + assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error + + +def test_unit_id_order_independence(small_sorting_analyzer): + """ + Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels, + and checks that their calculated quality metrics are independent of the ordering and labelling. 
+ """ + + recording = small_sorting_analyzer.recording + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) + + small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + small_sorting_analyzer_2.compute(extensions_to_compute) + + # need special params to get non-nan results on a short recording + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + } + + quality_metrics_1 = compute_quality_metrics( + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True + ) + quality_metrics_2 = compute_quality_metrics( + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True + ) + + for metric, metric_2_data in quality_metrics_2.items(): + error = "Problem with the metric " + metric + assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error + assert quality_metrics_1[metric]["#9"] == metric_2_data[7], error + assert quality_metrics_1[metric]["#4"] == metric_2_data[1], error + + +if __name__ == "__main__": + pass + # sorting_analyzer = _sorting_analyzer_simple() + # print(sorting_analyzer) + # test_unit_structure_in_output(_small_sorting_analyzer()) + # test_calculate_firing_rate_num_spikes(sorting_analyzer) # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index ec72fdc178..2e87002018 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -168,7 +168,8 @@ def test_empty_units(sorting_analyzer_simple): for col in metrics_empty.columns: all_nans = np.all(isnull(metrics_empty.loc[empty_unit_ids, col].values)) all_zeros = np.all(metrics_empty.loc[empty_unit_ids, col].values == 0) - assert all_nans or all_zeros + all_neg_ones = np.all(metrics_empty.loc[empty_unit_ids, col].values == -1) + assert all_nans or all_zeros or all_neg_ones, f"Column {col} failed the empty unit test" if __name__ == "__main__": diff --git a/src/spikeinterface/metrics/quality/utils.py b/src/spikeinterface/metrics/quality/utils.py deleted file mode 100644 index 844a7da7f5..0000000000 --- a/src/spikeinterface/metrics/quality/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. - - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. 
-    total_points : array-like
-        Number of points in each unit distribution.
-
-    Returns
-    -------
-    all_pcs : numpy.ndarray
-        PC scores for each point.
-    all_labels : numpy.array
-        Labels for each point.
-    """
-    from scipy.stats import multivariate_normal
-
-    np.random.seed(0)
-
-    if len(np.array(center_locations).shape) == 1:
-        distributions = [
-            multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size)
-            for center, size in zip(center_locations, total_points)
-        ]
-        all_pcs = np.concatenate(distributions, axis=0)
-
-    else:
-        all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0]))
-        for channel in range(center_locations.shape[0]):
-            distributions = [
-                multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size)
-                for center, size in zip(center_locations[channel], total_points)
-            ]
-            all_pcs[:, :, channel] = np.concatenate(distributions, axis=0)
-
-    all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))])
-
-    return all_pcs, all_labels
diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py
index 0467cc15da..652be32955 100644
--- a/src/spikeinterface/metrics/spiketrain/metrics.py
+++ b/src/spikeinterface/metrics/spiketrain/metrics.py
@@ -1,8 +1,10 @@
 import numpy as np
+
 from spikeinterface.core.analyzer_extension_core import BaseMetric
+from spikeinterface.metrics.utils import compute_total_durations_per_unit
 
 
-def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs):
+def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None):
     """
     Compute the number of spikes across segments.
 
@@ -12,24 +14,24 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs):
         A SortingAnalyzer object.
     unit_ids : list or None
         The list of unit ids to compute the number of spikes. If None, all units are used.
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit.
 
     Returns
    -------
     num_spikes : dict
         The number of spikes, across all segments, for each unit ID.
     """
     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods)
     if unit_ids is None:
         unit_ids = sorting.unit_ids
-
+    # re-order dict to match unit_ids order
+    count_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
     num_spikes = {}
-
-    total_num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit()
     for unit_id in unit_ids:
-        num_spikes[unit_id] = total_num_spikes[unit_id]
-
+        num_spikes[unit_id] = count_spikes[unit_id]
     return num_spikes
 
 
@@ -39,9 +40,10 @@ class NumSpikes(BaseMetric):
     metric_params = {}
     metric_descriptions = {"num_spikes": "Total number of spikes for each unit across all segments."}
     metric_columns = {"num_spikes": int}
+    supports_periods = True
 
 
-def compute_firing_rates(sorting_analyzer, unit_ids=None):
+def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None):
     """
     Compute the firing rate across segments.
 
@@ -51,6 +53,8 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None):
         A SortingAnalyzer object.
     unit_ids : list or None
         The list of unit ids to compute the firing rate. If None, all units are used.
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit.
 
     Returns
     -------
 
@@ -59,17 +63,18 @@
     """
     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods)
     if unit_ids is None:
         unit_ids = sorting.unit_ids
 
-    total_duration = sorting_analyzer.get_total_duration()
+    total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods)
 
     firing_rates = {}
-    num_spikes = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids)
+    num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
     for unit_id in unit_ids:
         if num_spikes[unit_id] == 0:
             firing_rates[unit_id] = np.nan
         else:
-            firing_rates[unit_id] = num_spikes[unit_id] / total_duration
+            firing_rates[unit_id] = num_spikes[unit_id] / total_durations[unit_id]
     return firing_rates
 
 
@@ -79,6 +84,7 @@ class FiringRate(BaseMetric):
     metric_params = {}
     metric_descriptions = {"firing_rate": "Firing rate (spikes per second) for each unit across all segments."}
     metric_columns = {"firing_rate": float}
+    supports_periods = True
 
 
 spiketrain_metrics = [NumSpikes, FiringRate]
diff --git a/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py
new file mode 100644
index 0000000000..7577c767d6
--- /dev/null
+++ b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py
@@ -0,0 +1,31 @@
+import numpy as np
+
+from spikeinterface.core.base import unit_period_dtype
+from spikeinterface.metrics.utils import create_regular_periods
+from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates
+
+
+def test_calculate_num_spikes(sorting_analyzer_simple):
+    sorting_analyzer = sorting_analyzer_simple
+    num_spikes = compute_num_spikes(sorting_analyzer)
+    periods = create_regular_periods(sorting_analyzer, num_periods=5)
+    num_spikes_periods = compute_num_spikes(sorting_analyzer, periods=periods)
+    assert num_spikes == num_spikes_periods
+
+    # calculate num spikes with empty periods
+    empty_periods = np.empty(0, dtype=unit_period_dtype)
+    num_spikes_empty_periods = compute_num_spikes(sorting_analyzer, periods=empty_periods)
+    assert num_spikes_empty_periods == {unit_id: 0 for unit_id in sorting_analyzer.sorting.unit_ids}
+
+
+def test_calculate_firing_rates(sorting_analyzer_simple):
+    sorting_analyzer = sorting_analyzer_simple
+    firing_rates = compute_firing_rates(sorting_analyzer)
+    periods = create_regular_periods(sorting_analyzer, num_periods=5)
+    firing_rates_periods = compute_firing_rates(sorting_analyzer, periods=periods)
+    assert firing_rates == firing_rates_periods
+
+    # calculate firing rates with empty periods
+    empty_periods = np.empty(0, dtype=unit_period_dtype)
+    firing_rates_empty_periods = compute_firing_rates(sorting_analyzer, periods=empty_periods)
+    assert np.all(np.isnan(np.array(list(firing_rates_empty_periods.values()))))
diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py
index e27f16963d..85ef9e22cb 100644
--- a/src/spikeinterface/metrics/template/template_metrics.py
+++ b/src/spikeinterface/metrics/template/template_metrics.py
@@ -131,6 +131,7 @@ def _set_params(
         metric_params: dict | None = None,
         delete_existing_metrics: bool = False,
         metrics_to_compute: list[str] | None = None,
+        periods=None,
         # common extension kwargs
         peak_sign="neg",
        upsampling_factor=10,
@@ -160,6 +161,7 @@ def _set_params(
             metric_params=metric_params,
             delete_existing_metrics=delete_existing_metrics,
             metrics_to_compute=metrics_to_compute,
+            periods=periods,  # template metrics do not use periods
             peak_sign=peak_sign,
             upsampling_factor=upsampling_factor,
             include_multi_channel_metrics=include_multi_channel_metrics,
diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py
new file mode 100644
index 0000000000..b7b94294eb
--- /dev/null
+++ b/src/spikeinterface/metrics/utils.py
@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+import numpy as np
+from spikeinterface.core.base import unit_period_dtype
+
+
+def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None, concatenated=True):
+    """
+    Compute bin edges for units, optionally taking into account periods.
+
+    Parameters
+    ----------
+    sorting : Sorting
+        Sorting object containing unit information.
+    segment_samples : list or array-like
+        Number of samples in each segment.
+    bin_duration_s : float, default: 1
+        Duration of each bin in seconds.
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit.
+    concatenated : bool, default: True
+        Whether the bins are concatenated across segments or not.
+        If False, the bin edges are computed per segment and the first index of each segment is 0.
+        If True, the bin edges are computed on the concatenated segments, with the correct offsets.
+
+    Returns
+    -------
+    dict
+        Bin edges for each unit. If concatenated is True, the bin edges are a 1D array.
+        If False, the bin edges are a list of arrays, one per segment.
+    """
+    bin_edges_for_units = {}
+    num_segments = len(segment_samples)
+    bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency)
+
+    if periods is not None:
+        for unit_id in sorting.unit_ids:
+            unit_index = sorting.id_to_index(unit_id)
+            periods_unit = periods[periods["unit_index"] == unit_index]
+            bin_edges = []
+            for seg_index in range(num_segments):
+                seg_periods = periods_unit[periods_unit["segment_index"] == seg_index]
+                if len(seg_periods) == 0:
+                    if not concatenated:
+                        bin_edges.append(np.array([]))
+                    continue
+                seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0
+                bin_edges_segment = []
+                for period in seg_periods:
+                    start_sample = seg_start + period["start_sample_index"]
+                    end_sample = seg_start + period["end_sample_index"]
+                    end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1  # align to bin
+                    bin_edges_segment.extend(np.arange(start_sample, end_sample, bin_duration_samples))
+                bin_edges_segment = np.unique(np.array(bin_edges_segment))
+                if concatenated:
+                    bin_edges.extend(bin_edges_segment)
+                else:
+                    bin_edges.append(bin_edges_segment)
+            bin_edges_for_units[unit_id] = bin_edges
+    else:
+        for unit_id in sorting.unit_ids:
+            bin_edges = []
+            for seg_index in range(num_segments):
+                seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0
+                seg_end = seg_start + segment_samples[seg_index]
+                # for segments which are not the last, we don't need to correct the end
+                # since the first index of the next segment will be the end of the current segment
+                if seg_index == num_segments - 1:
+                    seg_end = seg_end // bin_duration_samples * bin_duration_samples + 1  # align to bin
+                bin_edges_segment = np.arange(seg_start, seg_end, bin_duration_samples)
+                if concatenated:
+                    bin_edges.extend(bin_edges_segment)
+                else:
+                    bin_edges.append(bin_edges_segment)
+            bin_edges_for_units[unit_id] = bin_edges
+
+    return bin_edges_for_units
+
+
+def compute_total_samples_per_unit(sorting_analyzer, periods=None):
+    """
+    Get the total number of samples for each unit, optionally taking into account periods.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object.
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit.
+
+    Returns
+    -------
+    dict
+        Total number of samples for each unit.
+    """
+    if periods is not None:
+        total_samples_array = np.zeros(len(sorting_analyzer.unit_ids), dtype="int64")
+        sorting = sorting_analyzer.sorting
+        for period in periods:
+            unit_index = period["unit_index"]
+            total_samples_array[unit_index] += period["end_sample_index"] - period["start_sample_index"]
+        total_samples = dict(zip(sorting.unit_ids, total_samples_array))
+    else:
+        total = sorting_analyzer.get_total_samples()
+        total_samples = {unit_id: total for unit_id in sorting_analyzer.unit_ids}
+    return total_samples
+
+
+def compute_total_durations_per_unit(sorting_analyzer, periods=None):
+    """
+    Compute the total duration for each unit, optionally taking into account periods.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object.
+    periods : array of unit_period_dtype, default: None
+        Periods to consider for each unit.
+
+    Returns
+    -------
+    dict
+        Total duration for each unit.
+    """
+    total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods)
+    total_durations = {
+        unit_id: samples / sorting_analyzer.sampling_frequency for unit_id, samples in total_samples.items()
+    }
+    return total_durations
+
+
+def create_regular_periods(sorting_analyzer, num_periods, bin_size_s=None):
+    """
+    Creates regular periods for each unit in the sorting analyzer.
+    The periods span the total duration of the recording, divided into
+    smaller periods either by the number of periods or by the size of each bin.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer containing the units and recording information.
+    num_periods : int
+        The number of periods to divide the total duration into (used if bin_size_s is None).
+    bin_size_s : float, default: None
+        If given, period durations will be a multiple of this size in seconds.
+
+    Returns
+    -------
+    periods
+        np.ndarray of dtype unit_period_dtype containing the segment index, start and end sample indices, and unit index.
+ """ + all_periods = [] + for segment_index in range(sorting_analyzer.recording.get_num_segments()): + samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods + if bin_size_s is not None: + bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) + samples_per_period = samples_per_period // bin_size_samples * bin_size_samples + num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) + for i, period_start in enumerate(period_starts): + period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) + periods_per_unit[i]["segment_index"] = segment_index + periods_per_unit[i]["start_sample_index"] = period_start + periods_per_unit[i]["end_sample_index"] = period_end + periods_per_unit[i]["unit_index"] = unit_index + all_periods.append(periods_per_unit) + return np.concatenate(all_periods) + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. + all_labels : numpy.array + Labels for each point. + """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels
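The `periods` machinery introduced above is driven entirely by `unit_period_dtype` structured arrays. A minimal sketch of building one by hand, under the assumption of a single-segment 30 kHz recording with at least two units; the four field names are the ones `create_regular_periods` fills in, and assigning by field name avoids depending on their internal order:

```python
import numpy as np
from spikeinterface.core.base import unit_period_dtype

# Restrict unit 0 to the first second and unit 1 to the second second
# of segment 0 (assumed 30 kHz sampling frequency).
fs = 30_000
periods = np.zeros(2, dtype=unit_period_dtype)
periods["segment_index"] = [0, 0]
periods["unit_index"] = [0, 1]
periods["start_sample_index"] = [0, fs]
periods["end_sample_index"] = [fs, 2 * fs]

# The edge case exercised throughout the tests above: an empty periods
# array selects no spikes at all.
empty_periods = np.empty(0, dtype=unit_period_dtype)
```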
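How the periods-aware metric functions are meant to be called, mirroring the new spiketrain tests. A sketch assuming an existing `sorting_analyzer`: because `create_regular_periods` tiles the whole recording, the periods-aware results must match the unrestricted ones, while empty periods collapse counts to 0 and rate-like values to NaN:

```python
import numpy as np
from spikeinterface.core.base import unit_period_dtype
from spikeinterface.metrics.utils import create_regular_periods
from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates

# Tile each segment into 5 equal periods for every unit.
periods = create_regular_periods(sorting_analyzer, num_periods=5)

# Whole-recording periods reproduce the unrestricted metrics exactly.
assert compute_num_spikes(sorting_analyzer) == compute_num_spikes(sorting_analyzer, periods=periods)
assert compute_firing_rates(sorting_analyzer) == compute_firing_rates(sorting_analyzer, periods=periods)

# Empty periods keep no spikes: counts drop to 0, firing rates to NaN.
empty = np.empty(0, dtype=unit_period_dtype)
assert all(n == 0 for n in compute_num_spikes(sorting_analyzer, periods=empty).values())
assert all(np.isnan(r) for r in compute_firing_rates(sorting_analyzer, periods=empty).values())
```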
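The firing-rate change above divides each unit's spike count by that unit's own duration rather than by the global recording duration. A sketch of the helper pair from `metrics/utils.py`, again assuming an existing `sorting_analyzer`:

```python
from spikeinterface.metrics.utils import create_regular_periods, compute_total_durations_per_unit

# With periods=None, every unit is assigned the full recording duration.
full = compute_total_durations_per_unit(sorting_analyzer)

# With periods, each unit's duration is the sum of its own period lengths.
# Regular whole-recording periods should give back the same totals.
periods = create_regular_periods(sorting_analyzer, num_periods=5)
per_period = compute_total_durations_per_unit(sorting_analyzer, periods=periods)
for unit_id in sorting_analyzer.unit_ids:
    assert abs(full[unit_id] - per_period[unit_id]) < 1.0  # allow for rounding at period edges
```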
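`compute_bin_edges_per_unit` is the shared binning helper behind the bin-based metrics. A usage sketch under the same `sorting_analyzer` assumption, showing how `concatenated` switches between one global edge array per unit and a per-segment list:

```python
from spikeinterface.metrics.utils import compute_bin_edges_per_unit

sorting = sorting_analyzer.sorting
segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting.get_num_segments())]

# One set of 1 s bin edges per unit, spanning the concatenated segments
# with the correct per-segment sample offsets.
edges_concat = compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, concatenated=True)

# Per-segment edges instead, each segment starting again at sample 0.
edges_per_seg = compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, concatenated=False)
```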
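Putting the pieces together: a hypothetical period-aware metric following the same pattern as `NumSpikes`/`FiringRate` above (a metric function that accepts `periods` and restricts the sorting via `select_periods`, plus `supports_periods = True` on the `BaseMetric` subclass). `compute_mean_isi`, `MeanISI`, and the `metric_name` attribute are illustrative assumptions, not part of this diff:

```python
import numpy as np
from spikeinterface.core.analyzer_extension_core import BaseMetric


def compute_mean_isi(sorting_analyzer, unit_ids=None, periods=None):
    # Same pattern as compute_num_spikes: restrict the sorting first, then compute.
    # select_periods is assumed to be a no-op when periods is None, as in the diff.
    sorting = sorting_analyzer.sorting
    sorting = sorting.select_periods(periods)
    if unit_ids is None:
        unit_ids = sorting.unit_ids
    fs = sorting.sampling_frequency
    mean_isi = {}
    for unit_id in unit_ids:
        isis = []
        for segment_index in range(sorting.get_num_segments()):
            spike_train = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
            if spike_train.size > 1:
                isis.append(np.diff(spike_train) / fs)
        # NaN when no ISI can be formed, matching the empty-periods convention above
        mean_isi[unit_id] = float(np.mean(np.concatenate(isis))) if isis else np.nan
    return mean_isi


class MeanISI(BaseMetric):
    metric_name = "mean_isi"
    metric_function = compute_mean_isi
    metric_params = {}
    metric_descriptions = {"mean_isi": "Mean inter-spike interval (s) for each unit."}
    metric_columns = {"mean_isi": float}
    supports_periods = True
```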