7 changes: 0 additions & 7 deletions src/spikeinterface/benchmark/__init__.py

This file was deleted.

32 changes: 24 additions & 8 deletions src/spikeinterface/benchmark/benchmark_base.py
@@ -9,7 +9,7 @@
import time


from spikeinterface.core import SortingAnalyzer
from spikeinterface.core import SortingAnalyzer, NumpySorting
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.widgets import get_some_colors
@@ -118,12 +118,21 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
if isinstance(data, tuple):
# old case : rec + sorting
rec, gt_sorting = data
analyzer = create_sorting_analyzer(
gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
)
analyzer.compute("random_spikes")
analyzer.compute("templates")
analyzer.compute("noise_levels")

if gt_sorting is not None:
analyzer = create_sorting_analyzer(
gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
)
analyzer.compute("random_spikes")
analyzer.compute("templates")
analyzer.compute("noise_levels")
else:
# some studies/benchmarks have no GT sorting
# in that case we still need an analyzer for the internal API
gt_sorting = NumpySorting.from_samples_and_labels([np.array([])], [np.array([])], rec.sampling_frequency, unit_ids=None)
analyzer = create_sorting_analyzer(
gt_sorting, rec, sparse=False, format="binary_folder", folder=local_analyzer_folder
)
else:
# new case : analyzer
assert isinstance(data, SortingAnalyzer)
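
Note: a minimal sketch of the empty-GT placeholder introduced above, assuming `NumpySorting.from_samples_and_labels` takes one sample array and one label array per segment (the sampling rate is illustrative):

import numpy as np
from spikeinterface.core import NumpySorting

# hypothetical: a sorting with zero spikes and zero units, used only so that
# create_sorting_analyzer has a sorting object to wrap
empty_sorting = NumpySorting.from_samples_and_labels(
    [np.array([], dtype="int64")],  # spike sample indices, one array per segment
    [np.array([], dtype="int64")],  # matching unit labels
    30_000.0,  # assumed sampling frequency for illustration
    unit_ids=None,
)
assert empty_sorting.get_num_units() == 0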
@@ -566,7 +575,10 @@ def _save_keys(self, saved_keys, folder):
elif format == "zarr_templates":
self.result[k].to_zarr(folder / k)
elif format == "sorting_analyzer":
pass
analyzer_folder = folder / k
if analyzer_folder.exists():
shutil.rmtree(analyzer_folder)
self.result[k].save_as(format="binary_folder", folder=analyzer_folder)
else:
raise ValueError(f"Save error {k} {format}")

@@ -612,6 +624,10 @@ def load_folder(cls, folder):
if zarr_folder.exists():

result[k] = Templates.from_zarr(zarr_folder)
elif format == "sorting_analyzer":
analyzer_folder = folder / k
if analyzer_folder.exists():
result[k] = load_sorting_analyzer(analyzer_folder)

return result
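
Note: a minimal sketch of the round trip enabled by the new "sorting_analyzer" format (key name and folder layout are illustrative, assuming the usual `save_as` / `load_sorting_analyzer` API):

from pathlib import Path
from spikeinterface import load_sorting_analyzer

# hypothetical result key declared as ("sorter_analyzer", "sorting_analyzer")
results_folder = Path("study_folder") / "results" / "case0"  # assumed layout
# _save_keys writes it with:
#   analyzer.save_as(format="binary_folder", folder=results_folder / "sorter_analyzer")
# load_folder reads it back with:
analyzer = load_sorting_analyzer(results_folder / "sorter_analyzer")
print(analyzer)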

226 changes: 226 additions & 0 deletions src/spikeinterface/benchmark/benchmark_sorter_without_gt.py
@@ -0,0 +1,226 @@
"""
This replaces the previous `GroundTruthStudy`
"""

import numpy as np
from spikeinterface.core import NumpySorting, create_sorting_analyzer
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
from spikeinterface.sorters import run_sorter
from spikeinterface.comparison import compare_multiple_sorters

from .residual_analysis import analyse_residual


# TODO later integrate CollisionGTComparison optionally in this class.


class SorterBenchmarkWithoutGroundTruth(Benchmark):
def __init__(self, recording, gt_sorting, params, sorter_folder):
self.recording = recording
self.gt_sorting = gt_sorting
self.params = params
self.sorter_folder = sorter_folder
self.result = {}

def run(self):
# run one sorter; "sorter_name" must be in params
raw_sorting = run_sorter(recording=self.recording, folder=self.sorter_folder, **self.params)
sorting = NumpySorting.from_sorting(raw_sorting)
self.result = {"sorting": sorting}

def compute_result(self, residual_peak_threshold=6, **job_kwargs):

sorting = self.result['sorting']
analyzer = create_sorting_analyzer(
sorting, self.recording, sparse=True, format="memory", **job_kwargs
)
analyzer.compute("random_spikes")
analyzer.compute("templates")
analyzer.compute("noise_levels")
analyzer.compute(
{
"spike_amplitudes" : {},
"amplitude_scalings" : {"handle_collisions": False}
},
**job_kwargs)

analyzer.compute("quality_metrics", **job_kwargs)

residual, peaks = analyse_residual(
analyzer, detect_peaks_kwargs=dict(
method="locally_exclusive",
peak_sign="neg",
detect_threshold=residual_peak_threshold,
),
**job_kwargs
)

self.result["sorter_analyzer"] = analyzer
self.result["peaks_from_residual"] = peaks

_run_key_saved = [
("sorting", "sorting"),
]
_result_key_saved = [
("multi_comp", "pickle"),
("sorter_analyzer", "sorting_analyzer"),
("peaks_from_residual", "npy"),
]



class SorterStudyWithoutGroundTruth(BenchmarkStudy):
"""
This class is an alternative to SorterStudy for datasets that do not have a ground-truth sorting.
"""

benchmark_class = SorterBenchmarkWithoutGroundTruth

def create_benchmark(self, key):
dataset_key = self.cases[key]["dataset"]
recording, gt_sorting = self.datasets[dataset_key]
params = self.cases[key]["params"]
sorter_folder = self.folder / "sorters" / self.key_to_str(key)
benchmark = SorterBenchmarkWithoutGroundTruth(recording, gt_sorting, params, sorter_folder)
return benchmark

def _get_comparison_groups(self):
# multi-comparisons are done on all cases sharing the same dataset key.
case_keys = list(self.cases.keys())
groups = {}
for case_key in case_keys:
data_key = self.cases[case_key]['dataset']
if data_key not in groups:
groups[data_key] = []
groups[data_key].append(case_key)
return groups
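
Note: for illustration, the mapping this returns when two hypothetical cases share one dataset:

# cases = {"sorterA": {"dataset": "rec1", ...}, "sorterB": {"dataset": "rec1", ...}}
# _get_comparison_groups() -> {"rec1": ["sorterA", "sorterB"]}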

def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_score=0.5, chance_score=0.1, **result_params):
# Here we need a hack because the results are not computed case by case but all at once

assert case_keys is None, "SorterStudyWithoutGroundTruth does not permit compute_results on a subset of cases"

# always the full list
case_keys = list(self.cases.keys())

# First: this runs the case-by-case SorterBenchmarkWithoutGroundTruth.compute_result() internally
BenchmarkStudy.compute_results(self, case_keys=case_keys, verbose=verbose, **result_params)

# Then compute the multi-comparison for cases that share the same dataset key.
groups = self._get_comparison_groups()

for data_key, group in groups.items():

sorting_list = [self.get_result(key)['sorting'] for key in group]
name_list = [key for key in group]
multi_comp = compare_multiple_sorters(
sorting_list,
name_list=name_list,
delta_time=delta_time,
match_score=match_score,
chance_score=chance_score,
agreement_method="count",
n_jobs=-1,
spiketrain_mode="union",
verbose=verbose,
do_matching=True,
)
# the same multi comp is then stored for each case in the group
for key in group:
benchmark = self.benchmarks[key]
benchmark.result['multi_comp'] = multi_comp
benchmark.save_result(self.folder / "results" / self.key_to_str(key))
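
Note: a minimal end-to-end sketch of driving this study class, assuming the `create` / `run` API of `BenchmarkStudy` shown earlier in this diff (dataset keys, case keys, and sorter names are illustrative):

from spikeinterface.benchmark.benchmark_sorter_without_gt import SorterStudyWithoutGroundTruth

datasets = {"rec1": (recording, None)}  # no GT sorting for this dataset
cases = {
    "sorterA": {"dataset": "rec1", "label": "sorter A", "params": {"sorter_name": "tridesclous2"}},
    "sorterB": {"dataset": "rec1", "label": "sorter B", "params": {"sorter_name": "spykingcircus2"}},
}
study = SorterStudyWithoutGroundTruth.create(study_folder, datasets=datasets, cases=cases)
study.run()              # runs each sorter case
study.compute_results()  # per-case analyzers, then one multi-comparison per dataset
study.plot_residual_peak_amplitudes()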

def plot_residual_peak_amplitudes(self, figsize=None):
import matplotlib.pyplot as plt

groups = self._get_comparison_groups()
colors = self.get_colors()

for data_key, group in groups.items():
fig, ax = plt.subplots(figsize=figsize)

lim0, lim1 = np.inf, -np.inf

for key in group:
peaks = self.get_result(key)["peaks_from_residual"]

lim0 = min(lim0, np.min(peaks["amplitude"]))
lim1 = max(lim1, np.max(peaks["amplitude"]))

if lim1 < 0:
lim1 = 0
if lim0 > 0:
lim0 = 0
bins = np.linspace(lim0, lim1, 200)

for key in group:
peaks = self.get_result(key)["peaks_from_residual"]
count, bins = np.histogram(peaks["amplitude"], bins=bins)
ax.plot(bins[:-1], count, color=colors[key], label=self.cases[key]["label"])

ax.legend()
# def plot_quality_metrics_comparison_on_agreement(self, qm_name='rp_contamination', figsize=None):
# import matplotlib.pyplot as plt

# groups = self._get_comparison_groups()

# for data_key, group in groups.items():
# n = len(group)
# fig, axs = plt.subplots(ncols=n - 1, nrows=n - 1, figsize=figsize, squeeze=False)
# for i, key1 in enumerate(group):
# for j, key2 in enumerate(group):
# if i < j:
# ax = axs[i, j - 1]
# label1 = self.cases[key1]['label']
# label2 = self.cases[key2]['label']

# if i == j - 1:
# ax.set_xlabel(label2)
# ax.set_ylabel(label1)

# multi_comp = self.get_result(key1)['multi_comp']
# comp = multi_comp.comparisons[key1, key2]

# match_12 = comp.hungarian_match_12
# if match_12.dtype.kind =='i':
# mask = match_12.values != -1
# if match_12.dtype.kind =='U':
# mask = match_12.values != ''

# common_unit1_ids = match_12[mask].index
# common_unit2_ids = match_12[mask].values
# metrics1 = self.get_result(key1)["sorter_analyzer"].get_extension("quality_metrics").get_data()
# metrics2 = self.get_result(key2)["sorter_analyzer"].get_extension("quality_metrics").get_data()

# values1 = metrics1.loc[common_unit1_ids, qm_name].values
# values2 = metrics2.loc[common_unit2_ids, qm_name].values

# print(common_unit1_ids, metrics1.columns, values1)
# print(common_unit2_ids, metrics2.columns, values2)

# ax.scatter(values1, values2)
# if i != j - 1:
# ax.set_xlabel("")
# ax.set_ylabel("")
# ax.set_xticks([])
# ax.set_yticks([])
# ax.set_xticklabels([])
# ax.set_yticklabels([])


# def plot_quality_metrics_comparison_on_non_agreement(self, qm_name='rp_contamination', figsize=None):
# import matplotlib.pyplot as plt

# groups = self._get_comparison_groups()

# for data_key, group in groups.items():
# n = len(group)
# fig, ax = plt.subplots(figsize=figsize)
# for key in group:
# pass

75 changes: 75 additions & 0 deletions src/spikeinterface/benchmark/residual_analysis.py
@@ -0,0 +1,75 @@
from spikeinterface.core.generate import InjectTemplatesRecording



def analyse_residual(
analyzer,
detect_peaks_kwargs=dict(
method="locally_exclusive",
peak_sign="both",
detect_threshold=6.,

),
**job_kwargs
):
"""
This creates the residual by subtracting each detected spike from the recording.
This takes the spike amplitude scaling into account, so the analyzer needs the "amplitude_scalings" extension.
Then a peak detector is run on the residual traces and the number of detected peaks can be analyzed (the fewer the better).

This residual is not perfect at the moment because it does not take the per-spike jitter into account,
so the residual can be high for high-amplitude units when there is inherent jitter per spike.

Parameters
----------
analyzer : SortingAnalyzer
The sorting analyzer, with the "templates" and "amplitude_scalings" extensions computed.

Returns
-------
residual : Recording
The residual recording.
peaks : np.array
The peaks vector detected on the residual.

"""
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

residual = make_residual_recording(analyzer)

peaks = detect_peaks(residual, **detect_peaks_kwargs, **job_kwargs)

return residual, peaks
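
Note: a minimal usage sketch, assuming the analyzer already has the extensions this function relies on (threshold and job kwargs are illustrative):

# the analyzer needs "templates" and "amplitude_scalings" computed beforehand
analyzer.compute(["random_spikes", "templates", "amplitude_scalings"])
residual, peaks = analyse_residual(
    analyzer,
    detect_peaks_kwargs=dict(method="locally_exclusive", peak_sign="neg", detect_threshold=6.0),
    n_jobs=-1,
)
print(peaks.size, "peaks detected in the residual")  # fewer peaks = better fit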




def make_residual_recording(analyzer):
"""
This makes a lazy residual recording from an analyzer.

Parameters
----------
analyzer : SortingAnalyzer
The sorting analyzer, with the "templates" and "amplitude_scalings" extensions computed.

Returns
-------
residual : Recording
The residual recording.
"""

templates = analyzer.get_extension("templates").get_templates(outputs="Templates")
neg_templates_array = templates.templates_array.copy()
neg_templates_array *= -1

amplitude_factor = analyzer.get_extension("amplitude_scalings").get_data()

residual = InjectTemplatesRecording(
analyzer.sorting,
neg_templates_array,
nbefore=templates.nbefore,
parent_recording=analyzer.recording,
amplitude_factor=amplitude_factor,
)
residual.name = "ResidualRecording"

return residual
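
Note: the returned recording lazily computes, chunk by chunk, the raw traces minus each spike's scaled template. A hypothetical dense check on a short slice of one segment (for illustration only):

import numpy as np

raw = analyzer.recording.get_traces(start_frame=0, end_frame=30_000)
res = residual.get_traces(start_frame=0, end_frame=30_000)
# res ~ raw - sum over spikes of (scaling * template) placed at each spike time,
# so the residual should usually be smaller than the raw traces on average
print(np.abs(res).mean(), "vs", np.abs(raw).mean())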