From e0c3368be84c50cf0c2046c05a17e5edd8f873f5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Dec 2025 16:52:36 +0100 Subject: [PATCH 1/2] Add all-in-one run_compare_analyzers function --- spikeinterface_gui/__init__.py | 2 +- spikeinterface_gui/compareunitlistview.py | 209 ++++++ spikeinterface_gui/controller.py | 24 +- spikeinterface_gui/controllercomparison.py | 798 +++++++++++++++++++++ spikeinterface_gui/main.py | 183 ++++- spikeinterface_gui/viewlist.py | 5 +- spikeinterface_gui/waveformheatmapview.py | 1 - spikeinterface_gui/waveformview.py | 6 +- 8 files changed, 1210 insertions(+), 18 deletions(-) create mode 100644 spikeinterface_gui/compareunitlistview.py create mode 100644 spikeinterface_gui/controllercomparison.py diff --git a/spikeinterface_gui/__init__.py b/spikeinterface_gui/__init__.py index 90764da..cc7a9e9 100644 --- a/spikeinterface_gui/__init__.py +++ b/spikeinterface_gui/__init__.py @@ -12,5 +12,5 @@ from .version import version as __version__ -from .main import run_mainwindow, run_launcher +from .main import run_mainwindow, run_launcher, run_mainwindow_comparison diff --git a/spikeinterface_gui/compareunitlistview.py b/spikeinterface_gui/compareunitlistview.py new file mode 100644 index 0000000..df9350e --- /dev/null +++ b/spikeinterface_gui/compareunitlistview.py @@ -0,0 +1,209 @@ +import numpy as np + +import pyqtgraph as pg + +from .view_base import ViewBase + + +class CompareUnitListView(ViewBase): + """ + View for displaying unit comparison between two analyzers. + Shows matched units, their agreement scores, and spike counts. 
+ """ + _supported_backend = ['qt'] + _depend_on = ['comparison'] + _gui_help_txt = "Display comparison table between two sorting outputs" + _settings = [ + {"name": "matching_mode", "type": "list", "value": "hungarian", "options": ["hungarian", "best_match"]}, + ] + + def __init__(self, controller=None, parent=None, backend="qt"): + ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) + self.unit_dtype = self.controller.unit_ids.dtype + + + def _qt_make_layout(self): + from .myqt import QT + + self.layout = QT.QVBoxLayout() + + # Create table widget + self.table = QT.QTableWidget() + self.layout.addWidget(self.table) + + # Setup table + self.table.setSelectionBehavior(QT.QAbstractItemView.SelectRows) + self.table.setSelectionMode(QT.QAbstractItemView.SingleSelection) + self.table.itemSelectionChanged.connect(self._qt_on_selection_changed) + + # Setup table structure + self.table.setColumnCount(5) + self.table.setHorizontalHeaderLabels([ + f'Unit ({self.controller.analyzer1_name})', + f'Unit ({self.controller.analyzer2_name})', + 'Agreement Score', + f'#Spikes ({self.controller.analyzer1_name})', + f'#Spikes ({self.controller.analyzer2_name})' + ]) + self.table.setSortingEnabled(True) + # Sort by Agreement Score column (index 2) by default + self.table.sortItems(2, QT.Qt.DescendingOrder) + self.table.setSelectionMode(QT.QAbstractItemView.SingleSelection) + + def _qt_refresh(self): + """Populate/refresh the comparison table with data""" + from .myqt import QT + + comp = self.controller.comp + + # Get comparison data + if self.settings['matching_mode'] == 'hungarian': + matching_12 = comp.hungarian_match_12 + else: + matching_12 = comp.best_match_12 + matching_12 = comp.hungarian_match_12 + agreement_scores = comp.agreement_scores + + # Get all units from both analyzers + all_units2 = set(self.controller.analyzer2.unit_ids) + + # Build rows: matched pairs + unmatched units + rows = [] + + # Add matched units + for unit1_orig in 
matching_12.index: + unit2_orig = matching_12[unit1_orig] + if unit2_orig != -1: + # Get combined unit_ids + unit1_idx = list(self.controller.analyzer1.unit_ids).index(unit1_orig) + unit2_idx = list(self.controller.analyzer2.unit_ids).index(unit2_orig) + unit1 = self.controller.unit_ids1[unit1_idx] + unit2 = self.controller.unit_ids2[unit2_idx] + + score = agreement_scores.at[unit1_orig, unit2_orig] + num_spikes1 = self.controller.analyzer1.sorting.get_unit_spike_train(unit1_orig).size + num_spikes2 = self.controller.analyzer2.sorting.get_unit_spike_train(unit2_orig).size + + rows.append( + { + 'unit1': str(unit1), + 'unit2': str(unit2), + 'unit1_orig': unit1_orig, + 'unit2_orig': unit2_orig, + 'agreement_score': f"{score:.3f}", + 'num_spikes1': num_spikes1, + 'num_spikes2': num_spikes2 + } + ) + all_units2.discard(unit2_orig) + else: + # Unmatched unit from analyzer1 + unit1_idx = list(self.controller.analyzer1.unit_ids).index(unit1_orig) + unit1 = self.controller.unit_ids1[unit1_idx] + num_spikes1 = self.controller.analyzer1.sorting.get_unit_spike_train(unit1_orig).size + + rows.append({ + 'unit1': str(unit1), + 'unit2': '', + 'unit1_orig': unit1_orig, + 'unit2_orig': '', + 'agreement_score': '0', + 'num_spikes1': num_spikes1, + 'num_spikes2': 0 + }) + + # Add unmatched units from analyzer2 + print("Remaining unmatched units in analyzer2:", len(all_units2)) + for unit2_orig in all_units2: + unit2_idx = list(self.controller.analyzer2.unit_ids).index(unit2_orig) + unit2 = self.controller.unit_ids2[unit2_idx] + num_spikes2 = self.controller.analyzer2.sorting.get_unit_spike_train(unit2_orig).size + + rows.append({ + 'unit1': '', + 'unit2': str(unit2), + 'unit1_orig': '', + 'unit2_orig': unit2_orig, + 'agreement_score': '', + 'num_spikes1': 0, + 'num_spikes2': num_spikes2 + }) + + # Populate rows + print(len(rows), "rows to display in comparison table") + # Disable sorting while populating + self.table.setSortingEnabled(False) + self.table.setRowCount(len(rows)) + + 
for i, row in enumerate(rows): + # Unit 1 column with color + if row['unit1'] != '': + item1 = QT.QTableWidgetItem(str(row['unit1'])) + unit1 = np.array([row['unit1']]).astype(self.unit_dtype)[0] + color1 = self.controller.get_unit_color(unit1) + item1.setBackground(QT.QColor(int(color1[0] * 255), int(color1[1] * 255), int(color1[2] * 255))) + # Set text color to white or black depending on background brightness + brightness1 = 0.299 * color1[0] + 0.587 * color1[1] + 0.114 * color1[2] + text_color1 = QT.QColor(255, 255, 255) if brightness1 < 0.5 else QT.QColor(0, 0, 0) + item1.setForeground(text_color1) + else: + item1 = QT.QTableWidgetItem('') + self.table.setItem(i, 0, item1) + + # Unit 2 column with color + if row['unit2'] != '': + item2 = QT.QTableWidgetItem(str(row['unit2'])) + unit2 = np.array([row['unit2']]).astype(self.unit_dtype)[0] + color2 = self.controller.get_unit_color(unit2) + item2.setBackground(QT.QColor(int(color2[0] * 255), int(color2[1] * 255), int(color2[2] * 255))) + # Set text color to white or black depending on background brightness + brightness2 = 0.299 * color2[0] + 0.587 * color2[1] + 0.114 * color2[2] + text_color2 = QT.QColor(255, 255, 255) if brightness2 < 0.5 else QT.QColor(0, 0, 0) + item2.setForeground(text_color2) + else: + item2 = QT.QTableWidgetItem('') + self.table.setItem(i, 1, item2) + + # Other columns + self.table.setItem(i, 2, QT.QTableWidgetItem(row['agreement_score'])) + self.table.setItem(i, 3, QT.QTableWidgetItem(str(row['num_spikes1']))) + self.table.setItem(i, 4, QT.QTableWidgetItem(str(row['num_spikes2']))) + + # Re-enable sorting after populating + self.table.setSortingEnabled(True) + # Resize columns + self.table.resizeColumnsToContents() + + + def _qt_on_selection_changed(self): + """Handle row selection and update unit visibility""" + selected_rows = [] + for item in self.table.selectedItems(): + if item.column() != 1: continue + selected_rows.append(item.row()) + + row_idx = selected_rows[0] + # Get unit values 
from table items + unit1_item = self.table.item(row_idx, 0) + unit2_item = self.table.item(row_idx, 1) + + # Collect units to make visible + visible_units = [] + + if unit1_item is not None and unit1_item.text() != '': + unit1 = np.array([unit1_item.text()]).astype(self.unit_dtype)[0] + visible_units.append(unit1) + + if unit2_item is not None and unit2_item.text() != '': + unit2 = np.array([unit2_item.text()]).astype(self.unit_dtype)[0] + visible_units.append(unit2) + + print("Selected units:", visible_units) + # Update visibility + if visible_units: + self.controller.set_visible_unit_ids(visible_units) + self.notify_unit_visibility_changed() + + def on_unit_visibility_changed(self): + """Handle external unit visibility changes - could highlight selected row""" + pass \ No newline at end of file diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index c606449..4fb1a3a 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -8,8 +8,6 @@ from spikeinterface.widgets.utils import get_unit_colors from spikeinterface import compute_sparsity from spikeinterface.core import get_template_extremum_channel -import spikeinterface.postprocessing -import spikeinterface.qualitymetrics from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.core.core_tools import check_json from spikeinterface.curation import validate_curation_dict @@ -30,6 +28,7 @@ ) from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties + class Controller(): @@ -269,18 +268,21 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)} - spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False) + spike_vector2 = [] + for 
segment_index in range(num_seg): + seg_slice = self.segment_slices[segment_index] + spike_vector2.append(self.spikes[seg_slice]) self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2] # this is dict of list because per segment spike_indices[segment_index][unit_id] - spike_indices_abs = spike_vector_to_indices(spike_vector2, unit_ids, absolute_index=True) - spike_indices = spike_vector_to_indices(spike_vector2, unit_ids) + spike_indices_abs = spike_vector_to_indices(spike_vector2, self.unit_ids, absolute_index=True) + spike_indices = spike_vector_to_indices(spike_vector2, self.unit_ids) # this is flatten spike_per_seg = [s.size for s in spike_vector2] # dict[unit_id] -> all indices for this unit across segments self._spike_index_by_units = {} # dict[segment_index][unit_id] -> all indices for this unit for one segment self._spike_index_by_segment_and_units = spike_indices_abs - for unit_id in unit_ids: + for unit_id in self.unit_ids: inds = [] for seg_ind in range(num_seg): inds.append(spike_indices[seg_ind][unit_id] + int(np.sum(spike_per_seg[:seg_ind]))) @@ -684,11 +686,11 @@ def get_waveform_sweep(self): def get_waveforms_range(self): return np.nanmin(self.templates_average), np.nanmax(self.templates_average) - def get_waveforms(self, unit_id): - wfs = self.waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - if self.analyzer.sparsity is None: + def get_waveforms(self, unit_id, force_dense=False): + wfs = self.waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=force_dense) + if self.analyzer.sparsity is None or force_dense: # dense waveforms - chan_inds = np.arange(self.analyzer.recording.get_num_channels(), dtype='int64') + chan_inds = np.arange(self.analyzer.get_num_channels(), dtype='int64') else: # sparse waveforms chan_inds = self.analyzer.sparsity.unit_id_to_channel_indices[unit_id] @@ -1026,3 +1028,5 @@ def remove_category_from_unit(self, unit_id, category): elif lbl.get('labels') is not None 
and category in lbl.get('labels'): lbl['labels'].pop(category) self.curation_data["manual_labels"][ix] = lbl + + diff --git a/spikeinterface_gui/controllercomparison.py b/spikeinterface_gui/controllercomparison.py new file mode 100644 index 0000000..9a4feda --- /dev/null +++ b/spikeinterface_gui/controllercomparison.py @@ -0,0 +1,798 @@ +import time + +import numpy as np +import pandas as pd + + +from spikeinterface.widgets.utils import get_some_colors +from spikeinterface import compute_sparsity +from spikeinterface.core import get_template_extremum_channel +from spikeinterface.core.sorting_tools import spike_vector_to_indices +from spikeinterface.core.recording_tools import get_rec_attributes, do_recording_attributes_match +from spikeinterface.comparison import compare_two_sorters +from spikeinterface.widgets.utils import make_units_table_from_analyzer + + +spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'), + ('channel_index', 'int64'), ('segment_index', 'int64'), + ('visible', 'bool'), ('selected', 'bool'), ('rand_selected', 'bool')] + + +_default_main_settings = dict( + max_visible_units=10, + color_mode='color_by_unit', + use_times=False +) + +from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties + + + +class ControllerComparison(): + def __init__( + self, analyzer1=None, analyzer2=None, + analyzer1_name="1", analyzer2_name="2", + backend="qt", parent=None, verbose=False, with_traces=True, + displayed_unit_properties=None, + extra_unit_properties=None, skip_extensions=None, disable_save_settings_button=False + ): + self.views = [] + skip_extensions = skip_extensions if skip_extensions is not None else [] + + self.skip_extensions = skip_extensions + self.skip_extensions.extend(["principal_components", "correlograms", "isi_histograms", "template_similarity"]) + self.skip_extensions = list(set(self.skip_extensions)) + self.backend = backend + self.disable_save_settings_button = disable_save_settings_button + 
self.curation = False + # this is not to have a popup when closing + self.current_curation_saved = True + + if self.backend == "qt": + from .backend_qt import SignalHandler + self.signal_handler = SignalHandler(self, parent=parent) + + elif self.backend == "panel": + from .backend_panel import SignalHandler + self.signal_handler = SignalHandler(self, parent=parent) + + self.with_traces = with_traces + + self.analyzer1 = analyzer1 + self.analyzer2 = analyzer2 + self.analyzer1_name = analyzer1_name + self.analyzer2_name = analyzer2_name + assert self.analyzer1.get_extension("random_spikes") is not None + assert self.analyzer2.get_extension("random_spikes") is not None + + assert self.analyzer1.return_in_uV == self.analyzer2.return_in_uV + self.return_in_uV = self.analyzer1.return_in_uV + + # check recording attributes match + recording1 = None + recording2 = None + self.use_recordings = False + + try: + recording1 = self.analyzer1.recording + except: + pass + try: + recording2 = self.analyzer2.recording + except: + pass + if recording1 is not None and recording2 is not None: + match, diff = do_recording_attributes_match( + recording1, get_rec_attributes(recording2) + ) + if match: + self.use_recordings = True + + self.verbose = verbose + t0 = time.perf_counter() + + self.main_settings = _default_main_settings.copy() + + self.num_channels = self.analyzer1.get_num_channels() + # this now private and shoudl be acess using function + self._visible_unit_ids = [self.unit_ids[0]] + + # sparsity1 + if self.analyzer1.sparsity is None: + self.external_sparsity1 = compute_sparsity(self.analyzer1, method="radius",radius_um=90.) + self.analyzer_sparsity1 = None + else: + self.external_sparsity1 = None + self.analyzer_sparsity1 = self.analyzer1.sparsity + # sparsity2 + if self.analyzer2.sparsity is None: + self.external_sparsity2 = compute_sparsity(self.analyzer2, method="radius",radius_um=90.) 
+ self.analyzer_sparsity2 = None + else: + self.external_sparsity2 = None + self.analyzer_sparsity2 = self.analyzer2.sparsity + + + if verbose: + print("Comparing spike sorting outputs") + t0 = time.perf_counter() + self.comp = compare_two_sorters(self.analyzer1.sorting, self.analyzer2.sorting, + sorting1_name=self.analyzer1_name, sorting2_name=self.analyzer2_name) + if verbose: + print("Comparing took", time.perf_counter() - t0) + + # spikes + t0 = time.perf_counter() + if verbose: + print('Gathering all spikes') + self._extremum_channel1 = get_template_extremum_channel(self.analyzer1, peak_sign='neg', outputs='index') + self._extremum_channel2 = get_template_extremum_channel(self.analyzer2, peak_sign='neg', outputs='index') + self._extremum_channel = {} + for unit_id in self.unit_ids: + if unit_id in self.unit_ids1: + extremum_channels = self._extremum_channel1 + else: + extremum_channels = self._extremum_channel2 + self._extremum_channel[unit_id] = extremum_channels[self.get_original_unit_id(unit_id)] + + spike_vector1 = self.analyzer1.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._extremum_channel1) + spike_vector2 = self.analyzer2.sorting.to_spike_vector(concatenated=True, extremum_channel_inds=self._extremum_channel2) + + random_spikes_indices1 = self.analyzer1.get_extension("random_spikes").get_data() + random_spikes_indices2 = self.analyzer2.get_extension("random_spikes").get_data() + + self.spikes = np.zeros(spike_vector1.size + spike_vector2.size, dtype=spike_dtype) + self.spikes['sample_index'] = np.concatenate([spike_vector1['sample_index'], spike_vector2['sample_index']]) + self.spikes['unit_index'] = np.concatenate([spike_vector1['unit_index'], spike_vector2['unit_index'] + self.analyzer1.get_num_units()]) + self.spikes['segment_index'] = np.concatenate([spike_vector1['segment_index'], spike_vector2['segment_index']]) + self.spikes['channel_index'] = np.concatenate([spike_vector1['channel_index'], 
spike_vector2['channel_index']]) + self.spikes['rand_selected'][:] = False + self.spikes['rand_selected'][random_spikes_indices1] = True + self.spikes['rand_selected'][random_spikes_indices2 + spike_vector1.size] = True + + # sort spikes + num_seg = self.analyzer1.get_num_segments() + self.spike_order = np.argsort(self.spikes['sample_index'], kind='stable') + self.spikes = self.spikes[self.spike_order] + + seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1)) + self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)} + + spike_vector2 = [] + for segment_index in range(num_seg): + seg_slice = self.segment_slices[segment_index] + spike_vector2.append(self.spikes[seg_slice]) + self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2] + # this is dict of list because per segment spike_indices[segment_index][unit_id] + spike_indices_abs = spike_vector_to_indices(spike_vector2, self.unit_ids, absolute_index=True) + spike_indices = spike_vector_to_indices(spike_vector2, self.unit_ids) + # this is flatten + spike_per_seg = [s.size for s in spike_vector2] + # dict[unit_id] -> all indices for this unit across segments + self._spike_index_by_units = {} + # dict[segment_index][unit_id] -> all indices for this unit for one segment + self._spike_index_by_segment_and_units = spike_indices_abs + for unit_id in self.unit_ids: + inds = [] + for seg_ind in range(num_seg): + inds.append(spike_indices[seg_ind][unit_id] + int(np.sum(spike_per_seg[:seg_ind]))) + self._spike_index_by_units[unit_id] = np.concatenate(inds) + + t1 = time.perf_counter() + if verbose: + print('Gathering all spikes took', t1 - t0) + + if verbose: + print('Loading extensions') + # Mandatory extensions: computation forced + if verbose: + print('\tLoading templates') + temp_ext1 = self.analyzer1.get_extension("templates") + temp_ext2 = 
self.analyzer2.get_extension("templates") + assert temp_ext1 is not None and temp_ext2 is not None, "Both analyzers should have 'templates' extension" + self.nbefore, self.nafter = temp_ext1.nbefore, temp_ext1.nafter + + self.templates_average = np.vstack([temp_ext1.get_templates(operator='average'), temp_ext2.get_templates(operator='average')]) + + if 'std' in temp_ext1.params['operators']: + self.templates_std = np.vstack([temp_ext1.get_templates(operator='std'), temp_ext2.get_templates(operator='std')]) + else: + self.templates_std = None + + if verbose: + print('\tLoading unit_locations') + ext1 =self.analyzer1.get_extension('unit_locations') + ext2 = self.analyzer2.get_extension('unit_locations') + assert ext1 is not None and ext2 is not None, "Both analyzers should have 'unit_locations' extension" + self.unit_positions = np.vstack([ext1.get_data()[:, :2], ext2.get_data()[:, :2]]) + + # Optional extensions : can be None or skipped + if verbose: + print('\tLoading noise_levels') + ext1 = self.analyzer1.get_extension('noise_levels') + if ext1 is None and self.has_extension('recording'): + print('Force compute "noise_levels" is needed') + ext1 = self.analyzer1.compute_one_extension('noise_levels') + self.noise_levels = ext1.get_data() if ext1 is not None else None + + if "quality_metrics" in self.skip_extensions: + if self.verbose: + print('\tSkipping quality_metrics') + self.metrics = None + else: + if verbose: + print('\tLoading quality_metrics') + qm_ext1 = self.analyzer1.get_extension('quality_metrics') + qm_ext2 = self.analyzer2.get_extension('quality_metrics') + if qm_ext1 is not None and qm_ext2 is not None: + self.metrics = pd.concat([qm_ext1.get_data(), qm_ext2.get_data()]) + self.metrics.index = self.unit_ids + else: + self.metrics = None + + if "spike_amplitudes" in self.skip_extensions: + if self.verbose: + print('\tSkipping spike_amplitudes') + self.spike_amplitudes = None + else: + if verbose: + print('\tLoading spike_amplitudes') + sa_ext1 = 
self.analyzer1.get_extension('spike_amplitudes') + sa_ext2 = self.analyzer2.get_extension('spike_amplitudes') + if sa_ext1 is not None and sa_ext2 is not None: + self.spike_amplitudes = np.concatenate([sa_ext1.get_data(), sa_ext2.get_data()])[self.spike_order] + else: + self.spike_amplitudes = None + + if "spike_locations" in self.skip_extensions: + if self.verbose: + print('\tSkipping spike_locations') + self.spike_depths = None + else: + if verbose: + print('\tLoading spike_locations') + sl_ext1 = self.analyzer1.get_extension('spike_locations') + sl_ext2 = self.analyzer2.get_extension('spike_locations') + if sl_ext1 is not None and sl_ext2 is not None: + self.spike_depths = np.concatenate([sl_ext1.get_data()["y"], sl_ext2.get_data()["y"]])[self.spike_order] + else: + self.spike_depths = None + + # Correlograms, ISIs are skipped + self.correlograms, self.correlograms_bins = None, None + self.isi_histograms, self.isi_bins = None, None + + self._similarity_by_method = {} + # if "template_similarity" in self.skip_extensions: + # if self.verbose: + # print('\tSkipping template_similarity') + # else: + # if verbose: + # print('\tLoading template_similarity') + # ts_ext = analyzer.get_extension('template_similarity') + # if ts_ext is not None: + # method = ts_ext.params["method"] + # self._similarity_by_method[method] = ts_ext.get_data() + # else: + # if len(self.unit_ids) <= 64 and len(self.channel_ids) <= 64: + # # precompute similarity when low channel/units count + # method = 'l1' + # ts_ext = analyzer.compute_one_extension('template_similarity', method=method, save=save_on_compute) + # self._similarity_by_method[method] = ts_ext.get_data() + + if "waveforms" in self.skip_extensions: + if self.verbose: + print('\tSkipping waveforms') + self.waveforms_ext1, self.waveforms_ext2 = None, None + else: + if verbose: + print('\tLoading waveforms') + wf_ext1 = self.analyzer1.get_extension('waveforms') + wf_ext2 = self.analyzer2.get_extension('waveforms') + if wf_ext1 is not 
None and wf_ext2 is not None: + self.waveforms_ext1 = wf_ext1 + self.waveforms_ext2 = wf_ext2 + else: + self.waveforms_ext1, self.waveforms_ext2 = None, None + + self._pc_projections = None + if "principal_components" in self.skip_extensions: + if self.verbose: + print('\tSkipping principal_components') + self.pc_ext1, self.pc_ext2 = None, None + else: + if verbose: + print('\tLoading principal_components') + pc_ext1 = self.analyzer1.get_extension('principal_components') + pc_ext2 = self.analyzer2.get_extension('principal_components') + if pc_ext1 is not None and pc_ext2 is not None: + self.pc_ext1 = pc_ext1 + self.pc_ext2 = pc_ext2 + else: + self.pc_ext1, self.pc_ext2 = None, None + + self._potential_merges = None + + t1 = time.perf_counter() + if verbose: + print('Loading extensions took', t1 - t0) + + t0 = time.perf_counter() + + # some direct attribute + self.num_segments =self.analyzer1.get_num_segments() + self.sampling_frequency =self.analyzer1.sampling_frequency + self.num_spikes =self.analyzer1.sorting.count_num_spikes_per_unit(outputs="dict") + + # spikeinterface handle colors in matplotlib style tuple values in range (0,1) + self.refresh_colors() + + + self._spike_visible_indices = np.array([], dtype='int64') + self._spike_selected_indices = np.array([], dtype='int64') + self.update_visible_spikes() + + self._traces_cached = {} + + unit_tables = [] + for analyzer in [self.analyzer1, self.analyzer2]: + unit_table = make_units_table_from_analyzer(analyzer) + unit_tables.append(unit_table) + self.units_table = pd.concat(unit_tables, ignore_index=True) + self.units_table.index = self.unit_ids + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + displayed_unit_properties += list(extra_unit_properties.keys()) + displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] + self.displayed_unit_properties = 
displayed_unit_properties + + # set default time info + self.update_time_info() + + def check_is_view_possible(self, view_name): + from .viewlist import possible_class_views + view_class = possible_class_views[view_name] + if view_class._depend_on is not None: + depencies_ok = all(self.has_extension(k) for k in view_class._depend_on) + if not depencies_ok: + if self.verbose: + print(view_name, 'does not have all dependencies', view_class._depend_on) + return False + return True + + def declare_a_view(self, new_view): + assert new_view not in self.views, 'view already declared {}'.format(self) + self.views.append(new_view) + self.signal_handler.connect_view(new_view) + + @property + def channel_ids(self): + return self.analyzer1.channel_ids + + @property + def unit_ids1(self): + return self.unit_ids[:self.analyzer1.get_num_units()] + + @property + def unit_ids2(self): + return self.unit_ids[self.analyzer1.get_num_units():] + + def get_original_unit_id(self, unit_id): + """Get original unit id from analyzer1 or analyzer2 given combined unit_id""" + unit_index = list(self.unit_ids).index(unit_id) + if unit_index < self.analyzer1.get_num_units(): + return self.analyzer1.unit_ids[unit_index] + else: + return self.analyzer2.unit_ids[unit_index - self.analyzer1.get_num_units()] + + @property + def unit_ids(self): + if isinstance(self.analyzer1.unit_ids[0], np.integer) and isinstance(self.analyzer2.unit_ids[0], np.integer): + return np.concatenate((self.analyzer1.unit_ids, self.analyzer2.unit_ids + max(self.analyzer1.unit_ids) + 1)) + else: + analyzer1_ids = [str(uid) + f"_{self.analyzer1_name}" for uid in self.analyzer1.unit_ids] + analyzer2_ids = [str(uid) + f"_{self.analyzer2_name}" for uid in self.analyzer2.unit_ids] + return np.array(analyzer1_ids + analyzer2_ids) + + def get_time(self): + """ + Returns selected time and segment index + """ + segment_index = self.time_info['segment_index'] + time_by_seg = self.time_info['time_by_seg'] + time = 
time_by_seg[segment_index] + return time, segment_index + + def set_time(self, time=None, segment_index=None): + """ + Set selected time and segment index. + If time is None, then the current time is used. + If segment_index is None, then the current segment index is used. + """ + if segment_index is not None: + self.time_info['segment_index'] = segment_index + else: + segment_index = self.time_info['segment_index'] + if time is not None: + self.time_info['time_by_seg'][segment_index] = time + + def update_time_info(self): + # set default time info + if self.main_settings["use_times"] and self.has_extension("recording"): + time_by_seg=np.array( + [ + self.analyzer1.recording.get_start_time(segment_index) for segment_index in range(self.num_segments) + ], + dtype="float64" + ) + else: + time_by_seg=np.array([0] * self.num_segments, dtype="float64") + if not hasattr(self, 'time_info'): + self.time_info = dict( + time_by_seg=time_by_seg, + segment_index=0 + ) + else: + self.time_info['time_by_seg'] = time_by_seg + + def get_t_start_t_stop(self): + segment_index = self.time_info["segment_index"] + if self.main_settings["use_times"] and self.has_extension("recording"): + t_start =self.analyzer1.recording.get_start_time(segment_index=segment_index) + t_stop =self.analyzer1.recording.get_end_time(segment_index=segment_index) + return t_start, t_stop + else: + return 0, self.get_num_samples(segment_index) / self.sampling_frequency + + def get_times_chunk(self, segment_index, t1, t2): + ind1, ind2 = self.get_chunk_indices(t1, t2, segment_index) + if self.main_settings["use_times"]: + recording =self.analyzer1.recording + times_chunk = recording.get_times(segment_index=segment_index)[ind1:ind2] + else: + times_chunk = np.arange(ind2 - ind1, dtype='float64') / self.sampling_frequency + max(t1, 0) + return times_chunk + + def get_chunk_indices(self, t1, t2, segment_index): + if self.main_settings["use_times"]: + recording =self.analyzer1.recording + ind1, ind2 = 
recording.time_to_sample_index([t1, t2], segment_index=segment_index) + else: + t_start = 0.0 + sr = self.sampling_frequency + ind1 = int((t1 - t_start) * sr) + ind2 = int((t2 - t_start) * sr) + + ind1 = max(0, ind1) + ind2 = min(self.get_num_samples(segment_index), ind2) + return ind1, ind2 + + def sample_index_to_time(self, sample_index): + segment_index = self.time_info["segment_index"] + if self.main_settings["use_times"] and self.has_extension("recording"): + time =self.analyzer1.recording.sample_index_to_time(sample_index, segment_index=segment_index) + return time + else: + return sample_index / self.sampling_frequency + + def time_to_sample_index(self, time): + segment_index = self.time_info["segment_index"] + if self.main_settings["use_times"] and self.has_extension("recording"): + time =self.analyzer1.recording.time_to_sample_index(time, segment_index=segment_index) + return time + else: + return int(time * self.sampling_frequency) + + def get_information_txt(self): + nseg =self.analyzer1.get_num_segments() + nchan =self.analyzer1.get_num_channels() + nunits = len(self.unit_ids) + txt = f"{nchan} channels - {nunits} units - {nseg} segments - {self.analyzer1.format}\n" + txt += f"Loaded {len(self.analyzer1.extensions)} extensions" + + return txt + + def refresh_colors(self): + if self.backend == "qt": + self._cached_qcolors = {} + elif self.backend == "panel": + pass + + if self.main_settings['color_mode'] == 'color_by_unit': + self.colors = get_some_colors(self.unit_ids, color_engine='matplotlib', map_name='gist_ncar', + shuffle=True, seed=42) + elif self.main_settings['color_mode'] == 'color_only_visible': + unit_colors = get_some_colors(self.unit_ids, color_engine='matplotlib', map_name='gist_ncar', + shuffle=True, seed=42) + self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) 
for unit_id in self.unit_ids} + for unit_id in self.get_visible_unit_ids(): + self.colors[unit_id] = unit_colors[unit_id] + elif self.main_settings['color_mode'] == 'color_by_visibility': + self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids} + import matplotlib.pyplot as plt + cmap = plt.colormaps['tab10'] + for i, unit_id in enumerate(self.get_visible_unit_ids()): + self.colors[unit_id] = cmap(i) + + + def get_unit_color(self, unit_id): + # scalar unit_id -> color html or QtColor + return self.colors[unit_id] + + def get_spike_colors(self, unit_indices): + # array[unit_ind] -> array[color html or QtColor] + colors = np.zeros((unit_indices.size, 4), dtype="uint8") + unit_inds = np.unique(unit_indices) + for unit_ind in unit_inds: + unit_id = self.unit_ids[unit_ind] + mask = unit_indices == unit_ind + colors[mask] = np.array(self.get_unit_color(unit_id)) * 255 + return colors + + + def get_extremum_channel(self, unit_id): + chan_ind = self._extremum_channel[unit_id] + return chan_ind + + # unit visibility zone + def set_visible_unit_ids(self, visible_unit_ids): + """Make visible some units, all other off""" + lim = self.main_settings['max_visible_units'] + if len(visible_unit_ids) > lim: + visible_unit_ids = visible_unit_ids[:lim] + self._visible_unit_ids = list(visible_unit_ids) + + def get_visible_unit_ids(self): + """Get list of visible unit_ids""" + return self._visible_unit_ids + + def get_visible_unit_indices(self): + """Get list of indicies of visible units""" + unit_ids = list(self.unit_ids) + visible_unit_indices = [unit_ids.index(u) for u in self._visible_unit_ids] + return visible_unit_indices + + def set_all_unit_visibility_off(self): + """As in the name""" + self._visible_unit_ids = [] + + def iter_visible_units(self): + """For looping over unit_ind and unit_id""" + visible_unit_indices = self.get_visible_unit_indices() + visible_unit_ids = self._visible_unit_ids + return zip(visible_unit_indices, visible_unit_ids) + + def 
set_unit_visibility(self, unit_id, state):
+ """Change the visibility of one unit, others are unchanged"""
+ if state and not(unit_id in self._visible_unit_ids):
+ self._visible_unit_ids.append(unit_id)
+ elif not state and unit_id in self._visible_unit_ids:
+ self._visible_unit_ids.remove(unit_id)
+
+ def get_unit_visibility(self, unit_id):
+ """Get the visibility of one unit"""
+ return unit_id in self._visible_unit_ids
+
+ def get_units_visibility_mask(self):
+ """Get bool mask of visibility"""
+ mask = np.zeros(self.unit_ids.size, dtype='bool')
+ mask[self.get_visible_unit_indices()] = True
+ return mask
+
+ def get_dict_unit_visible(self):
+ """Construct the visibility dict keys are unit_ids, previous behavior"""
+ dict_unit_visible = {u:False for u in self.unit_ids}
+ for u in self.get_visible_unit_ids():
+ dict_unit_visible[u] = True
+ return dict_unit_visible
+ ## end unit visibility zone
+
+ def update_visible_spikes(self):
+ inds = []
+ for unit_index, unit_id in self.iter_visible_units():
+ inds.append(self._spike_index_by_units[unit_id])
+
+ if len(inds) > 0:
+ inds = np.concatenate(inds)
+ inds = np.sort(inds)
+ else:
+ inds = np.array([], dtype='int64')
+ self._spike_visible_indices = inds
+
+ self._spike_selected_indices = np.array([], dtype='int64')
+
+ def get_indices_spike_visible(self):
+ return self._spike_visible_indices
+
+ def get_indices_spike_selected(self):
+ return self._spike_selected_indices
+
+ def set_indices_spike_selected(self, inds):
+ self._spike_selected_indices = np.array(inds)
+ # reset active split if needed
+ if len(self._spike_selected_indices) == 1:
+ # set time info
+ segment_index = self.spikes['segment_index'][self._spike_selected_indices[0]]
+ sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]]
+ self.set_time(time=self.sample_index_to_time(sample_index), segment_index=segment_index)
+
+ def get_spike_indices(self, unit_id, segment_index=None):
+ if segment_index is None:
+ # dict[unit_id] -> all 
indices for this unit across segments + return self._spike_index_by_units[unit_id] + else: + # dict[segment_index][unit_id] -> all indices for this unit for one segment + return self._spike_index_by_segment_and_units[segment_index][unit_id] + + def get_num_samples(self, segment_index): + return self.analyzer1.get_num_samples(segment_index=segment_index) + + def get_traces(self, trace_source='preprocessed', **kargs): + # assert trace_source in ['preprocessed', 'raw'] + assert trace_source in ['preprocessed'] + + cache_key = (kargs.get("segment_index", None), kargs.get("start_frame", None), kargs.get("end_frame", None)) + if cache_key in self._traces_cached: + return self._traces_cached[cache_key] + else: + # check if start_frame and end_frame are a subset interval of a cached one + for cached_key in self._traces_cached.keys(): + cached_seg = cached_key[0] + cached_start = cached_key[1] + cached_end = cached_key[2] + req_seg = kargs.get("segment_index", None) + req_start = kargs.get("start_frame", None) + req_end = kargs.get("end_frame", None) + if cached_seg is not None and req_seg is not None: + if cached_seg != req_seg: + continue + if cached_start is not None and cached_end is not None and req_start is not None and req_end is not None: + if req_start >= cached_start and req_end <= cached_end: + # subset found + traces = self._traces_cached[cached_key] + start_offset = req_start - cached_start + end_offset = req_end - cached_start + return traces[start_offset:end_offset, :] + + if len(self._traces_cached) > 4: + self._traces_cached.pop(list(self._traces_cached.keys())[0]) + + if trace_source == 'preprocessed': + rec = self.analyzer1.recording + elif trace_source == 'raw': + raise NotImplementedError("Raw traces not implemented yet") + # TODO get with parent recording the non process recording + kargs['return_in_uV'] = self.return_in_uV + traces = rec.get_traces(**kargs) + # put in cache for next call + self._traces_cached[cache_key] = traces + return traces + + 
def get_contact_location(self):
+ location = self.analyzer1.get_channel_locations()
+ return location
+
+ def get_waveform_sweep(self):
+ return self.nbefore, self.nafter
+
+ def get_waveforms_range(self):
+ return np.nanmin(self.templates_average), np.nanmax(self.templates_average)
+
+ def get_waveforms(self, unit_id, force_dense=False):
+ if unit_id in self.unit_ids1:
+ self.waveforms_ext = self.waveforms_ext1
+ analyzer = self.analyzer1
+ else:
+ # unit belongs to the second analyzer: without this branch
+ # `analyzer` was unbound (NameError) for analyzer2 units
+ self.waveforms_ext = self.waveforms_ext2
+ analyzer = self.analyzer2
+ original_unit_id = self.get_original_unit_id(unit_id)
+ wfs = self.waveforms_ext.get_waveforms_one_unit(original_unit_id, force_dense=force_dense)
+ if analyzer.sparsity is None or force_dense:
+ # dense waveforms
+ chan_inds = np.arange(analyzer.get_num_channels(), dtype='int64')
+ else:
+ # sparse waveforms
+ chan_inds = analyzer.sparsity.unit_id_to_channel_indices[unit_id]
+ return wfs, chan_inds
+
+ def get_common_sparse_channels(self, unit_ids):
+ sparsity_mask = self.get_sparsity_mask()
+ unit_indexes = [list(self.unit_ids).index(u) for u in unit_ids]
+ chan_inds, = np.nonzero(sparsity_mask[unit_indexes, :].sum(axis=0))
+ return chan_inds
+
+ def get_intersect_sparse_channels(self, unit_ids):
+ sparsity_mask = self.get_sparsity_mask()
+ unit_indexes = [list(self.unit_ids).index(u) for u in unit_ids]
+ chan_inds, = np.nonzero(sparsity_mask[unit_indexes, :].sum(axis=0) == len(unit_ids))
+ return chan_inds
+
+ def get_probegroup(self):
+ return self.analyzer1.get_probegroup()
+
+ def set_channel_visibility(self, visible_channel_inds):
+ self.visible_channel_inds = np.array(visible_channel_inds, copy=True)
+
+ def has_extension(self, extension_name):
+ if extension_name == 'recording':
+ return self.use_recordings
+ elif extension_name == 'comparison':
+ return True
+ else:
+ # extension needs to be loaded
+ if extension_name in self.skip_extensions:
+ return False
+ else:
+ return extension_name in self.analyzer1.extensions
+
+ def handle_metrics(self):
+ return self.metrics is not None
+
+ def get_units_table(self):
+ 
return self.units_table
+
+ # TODO
+ def get_all_pcs(self):
+ return None, None
+ # if self._pc_projections is None and (self.pc_ext1 is not None and self.pc_ext2 is not None):
+ # pc_indices =
+ # for analyzer, pc_ext in zip([self.analyzer1, self.analyzer2], [self.pc_ext1, self.pc_ext2]):
+ # # make sure pcs are computed
+ # if pc_ext.get_data() is None:
+ # analyzer.compute_one_extension('principal_components', save=self.save_on_compute)
+ # self._pc_projections, self._pc_indices = self.pc_ext1.get_some_projections(
+ # channel_ids=self.analyzer1.channel_ids,
+ # unit_ids=self.analyzer1.unit_ids
+ # )
+
+ # return self._pc_indices, self._pc_projections
+ # else:
+ # return None, None
+
+ def get_sparsity_mask(self):
+ if self.external_sparsity1 is not None:
+ return np.vstack([self.external_sparsity1.mask, self.external_sparsity2.mask])
+ else:
+ return np.vstack([self.analyzer_sparsity1.mask, self.analyzer_sparsity2.mask])
+
+ def get_similarity(self, method=None):
+ if method is None and len(self._similarity_by_method) == 1:
+ method = list(self._similarity_by_method.keys())[0]
+ similarity = self._similarity_by_method.get(method, None)
+ return similarity
+
+ def compute_similarity(self, method='l1'):
+ # have internal cache
+ if method in self._similarity_by_method:
+ return self._similarity_by_method[method]
+ # NOTE(review): `self.analyzer` is not defined on this comparison controller
+ # (only analyzer1/analyzer2 exist) — this will raise AttributeError; confirm intent
+ ext = self.analyzer.compute("template_similarity", method=method, save=self.save_on_compute)
+ self._similarity_by_method[method] = ext.get_data()
+ return self._similarity_by_method[method]
+
+ def compute_unit_positions(self, method, method_kwargs):
+ unit_positions = np.zeros((len(self.unit_ids), 2), dtype='float32')
+ # stack positions of analyzer1 then analyzer2: previously both writes
+ # targeted [:len(analyzer.unit_ids)], so analyzer2 overwrote analyzer1
+ # and its own rows were left as zeros
+ offset = 0
+ for analyzer in [self.analyzer1, self.analyzer2]:
+ ext = analyzer.compute_one_extension('unit_locations', save=self.save_on_compute, method=method, **method_kwargs)
+ unit_positions[offset:offset + len(analyzer.unit_ids), :] = ext.get_data()[:, :2]
+ offset += len(analyzer.unit_ids)
+ self.unit_positions = unit_positions
+
+ # def 
get_correlograms(self): + # return self.correlograms, self.correlograms_bins + + # def compute_correlograms(self, window_ms, bin_ms): + # ext = self.analyzer.compute("correlograms", save=self.save_on_compute, window_ms=window_ms, bin_ms=bin_ms) + # self.correlograms, self.correlograms_bins = ext.get_data() + # return self.correlograms, self.correlograms_bins + + # def get_isi_histograms(self): + # return self.isi_histograms, self.isi_bins + + # def compute_isi_histograms(self, window_ms, bin_ms): + # ext = self.analyzer.compute("isi_histograms", save=self.save_on_compute, window_ms=window_ms, bin_ms=bin_ms) + # self.isi_histograms, self.isi_bins = ext.get_data() + # return self.isi_histograms, self.isi_bins + + def get_units_table(self): + return self.units_table + + def get_split_unit_ids(self): + return [] diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 2f373e4..95bd0a6 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -93,7 +93,6 @@ def run_mainwindow( disable_save_settings_button: bool, default: False If True, disables the "save default settings" button, so that user cannot do this. """ - if mode == "desktop": backend = "qt" elif mode == "web": @@ -231,6 +230,188 @@ def run_launcher(mode="desktop", analyzer_folders=None, root_folder=None, addres else: raise ValueError(f"spikeinterface-gui wrong mode {mode}") + +def run_mainwindow_comparison( + analyzer1, + analyzer2, + analyzer1_name=None, + analyzer2_name=None, + mode="desktop", + with_traces=True, + displayed_unit_properties=None, + extra_unit_properties=None, + skip_extensions=None, + recording=None, + start_app=True, + layout_preset=None, + layout=None, + address="localhost", + port=0, + panel_start_server_kwargs=None, + panel_window_servable=True, + verbose=False, + user_settings=None, + disable_save_settings_button=False, +): + """ + Create the main window and start the QT app loop. 
+
+ Parameters
+ ----------
+ analyzer1: SortingAnalyzer
+ The first sorting analyzer object
+ analyzer2: SortingAnalyzer
+ The second sorting analyzer object
+ analyzer1_name: str | None, default: None
+ The name to display for the first analyzer
+ analyzer2_name: str | None, default: None
+ The name to display for the second analyzer
+ mode: 'desktop' | 'web'
+ The GUI mode to use.
+ 'desktop' will run a Qt app.
+ 'web' will run a Panel app.
+ with_traces: bool, default: True
+ If True, traces are displayed
+ displayed_unit_properties: list | None, default: None
+ The displayed unit properties in the unit table
+ extra_unit_properties: list | None, default: None
+ The extra unit properties in the unit table
+ skip_extensions: list | None, default: None
+ The list of extensions to skip when loading the sorting analyzer
+ recording: RecordingExtractor | None, default: None
+ The recording object to display traces. This can be used when the
+ SortingAnalyzer is recordingless.
+ start_app: bool, default: True
+ If True, the app loop is started
+ layout_preset : str | None
+ The name of the layout preset. None is default.
+ layout : dict | None
+ The layout dictionary to use instead of the preset.
+ address: str, default : "localhost"
+ For "web" mode only. By default it is "localhost".
+ Use "auto-ip" to use the real IP address of the machine.
+ port: int, default: 0
+ For "web" mode only. If 0 then the port is automatic.
+ panel_start_server_kwargs: dict, default: None
+ For "web" mode only. Additional arguments to pass to the Panel server
+ - `{'show': True}` to automatically open the browser (default is True).
+ - `{'dev': True}` to enable development mode (default is False). 
+ - `{'autoreload': True}` to enable autoreload of the server when files change + (default is False). + panel_window_servable: bool, default: True + For "web" mode only. If True, the Panel app is made servable. + This is useful when embedding the GUI in another Panel app. In that case, + the `panel_window_servable` should be set to False. + verbose: bool, default: False + If True, print some information in the console + user_settings: dict, default: None + A dictionary of user settings for each view, which overwrite the default settings. + disable_save_settings_button: bool, default: False + If True, disables the "save default settings" button, so that user cannot do this. + """ + from .controllercomparison import ControllerComparison + + if mode == "desktop": + backend = "qt" + elif mode == "web": + raise NotImplementedError + else: + raise ValueError(f"spikeinterface-gui wrong mode {mode}") + + # Order of preference for settings is set here: + # 1) User specified settings + # 2) Settings in the config folder + # 3) Default settings of each view + if user_settings is None: + sigui_version = spikeinterface_gui.__version__ + config_version_folder = get_config_folder() / sigui_version + settings_file = config_version_folder / "settings.json" + if settings_file.is_file(): + try: + with open(settings_file) as f: + user_settings = json.load(f) + except json.JSONDecodeError as e: + print(f"Config file at {settings_file} is not decodable. 
Error: {e}")
+ print("Using default settings.")
+
+ if recording is not None:
+ analyzer1.set_temporary_recording(recording)
+ analyzer2.set_temporary_recording(recording)
+
+ if verbose:
+ import time
+ t0 = time.perf_counter()
+
+ views_to_remove = ["merge", "correlograms", "isi"]
+
+ layout_dict = get_layout_description(layout_preset, layout)
+ if skip_extensions is None:
+ skip_extensions = find_skippable_extensions(layout_dict)
+
+ for zone in layout_dict:
+ views_in_zone = layout_dict[zone]
+ if 'unitlist' in views_in_zone:
+ # substitute unitlist with compareunitlist
+ layout_dict[zone] = ['compareunitlist' if view == 'unitlist' else view for view in layout_dict[zone]]
+ for view in views_to_remove:
+ if view in layout_dict[zone]:
+ layout_dict[zone].remove(view)
+
+ controller = ControllerComparison(
+ analyzer1, analyzer2, analyzer1_name=analyzer1_name, analyzer2_name=analyzer2_name,
+ backend=backend, verbose=verbose,
+ with_traces=with_traces,
+ displayed_unit_properties=displayed_unit_properties,
+ extra_unit_properties=extra_unit_properties,
+ skip_extensions=skip_extensions,
+ disable_save_settings_button=disable_save_settings_button
+ )
+ if verbose:
+ t1 = time.perf_counter()
+ print('controller init time', t1 - t0)
+
+ if backend == "qt":
+ from spikeinterface_gui.myqt import QT, mkQApp
+ from spikeinterface_gui.backend_qt import QtMainWindow
+
+ # Suppress a known pyqtgraph warning
+ warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph")
+ warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*")
+
+ app = mkQApp()
+
+ win = QtMainWindow(controller, layout_dict=layout_dict, user_settings=user_settings)
+ win.setWindowTitle('SpikeInterface GUI')
+ # Set window icon
+ icon_file = Path(__file__).absolute().parent / 'img' / 'si.png'
+ if icon_file.exists():
+ app.setWindowIcon(QT.QIcon(str(icon_file)))
+ win.show()
+ if start_app:
+ app.exec()
+
+ # elif backend == "panel":
+ # 
from .backend_panel import PanelMainWindow, start_server + # win = PanelMainWindow(controller, layout_dict=layout_dict, user_settings=user_settings) + + # if start_app or panel_window_servable: + # win.main_layout.servable(title='SpikeInterface GUI') + + # if start_app: + # panel_start_server_kwargs = panel_start_server_kwargs or {} + # _ = start_server(win, address=address, port=port, **panel_start_server_kwargs) + + return win + + + def check_folder_is_analyzer(folder): """ Check if the given folder is a valid SortingAnalyzer folder. diff --git a/spikeinterface_gui/viewlist.py b/spikeinterface_gui/viewlist.py index c48e0ac..e012dc8 100644 --- a/spikeinterface_gui/viewlist.py +++ b/spikeinterface_gui/viewlist.py @@ -17,6 +17,8 @@ from .metricsview import MetricsView from .spikerateview import SpikeRateView +from .compareunitlistview import CompareUnitListView + # probe and mainsettings view are first, since they affect other views (e.g., time info) possible_class_views = dict( probe = ProbeView, @@ -36,5 +38,6 @@ tracemap = TraceMapView, curation = CurationView, spikerate = SpikeRateView, - metrics = MetricsView, + metrics = MetricsView, + compareunitlist = CompareUnitListView ) diff --git a/spikeinterface_gui/waveformheatmapview.py b/spikeinterface_gui/waveformheatmapview.py index 3006adf..0931318 100644 --- a/spikeinterface_gui/waveformheatmapview.py +++ b/spikeinterface_gui/waveformheatmapview.py @@ -53,7 +53,6 @@ def get_plotting_data(self): waveforms = [] for unit_id in visible_unit_ids: - wfs, channel_inds = self.controller.get_waveforms(unit_id) wfs, chan_inds = self.controller.get_waveforms(unit_id) keep = np.isin(chan_inds, intersect_sparse_indexes) waveforms.append(wfs[:, :, keep]) diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index ca17218..81c1b61 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -540,7 +540,6 @@ def _qt_refresh_with_spikes(self): if num_waveforms <= 0: 
self.curve_waveforms.setData([], []) return - wf_ext = self.controller.analyzer.get_extension("waveforms") visible_unit_ids = self.controller.get_visible_unit_ids() # Process waveforms per unit to maintain color association @@ -548,7 +547,7 @@ def _qt_refresh_with_spikes(self): width = None for unit_id in visible_unit_ids: - waveforms = wf_ext.get_waveforms_one_unit(unit_id, force_dense=True) + waveforms, _ = self.controller.get_waveforms(unit_id, force_dense=True) if waveforms is None or len(waveforms) == 0: continue @@ -1270,7 +1269,6 @@ def _panel_refresh_waveforms_samples(self): self.lines_data_source_wfs_geom.data = dict(xs=[], ys=[], colors=[]) return - wf_ext = self.controller.analyzer.get_extension("waveforms") visible_unit_ids = self.controller.get_visible_unit_ids() # Process waveforms per unit to maintain color association @@ -1278,7 +1276,7 @@ def _panel_refresh_waveforms_samples(self): width = None for unit_id in visible_unit_ids: - waveforms = wf_ext.get_waveforms_one_unit(unit_id, force_dense=True) + waveforms, _ = self.controller.get_waveforms(unit_id, force_dense=True) if waveforms is None or len(waveforms) == 0: continue From a1e3c0f3d0fc36d4010ff2d99f0799b77a436b3f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Dec 2025 17:08:34 +0100 Subject: [PATCH 2/2] Cleanup --- spikeinterface_gui/compareunitlistview.py | 57 ++++++++++++----------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/spikeinterface_gui/compareunitlistview.py b/spikeinterface_gui/compareunitlistview.py index df9350e..9f3566c 100644 --- a/spikeinterface_gui/compareunitlistview.py +++ b/spikeinterface_gui/compareunitlistview.py @@ -37,13 +37,11 @@ def _qt_make_layout(self): self.table.itemSelectionChanged.connect(self._qt_on_selection_changed) # Setup table structure - self.table.setColumnCount(5) + self.table.setColumnCount(3) self.table.setHorizontalHeaderLabels([ f'Unit ({self.controller.analyzer1_name})', f'Unit ({self.controller.analyzer2_name})', 
'Agreement Score', - f'#Spikes ({self.controller.analyzer1_name})', - f'#Spikes ({self.controller.analyzer2_name})' ]) self.table.setSortingEnabled(True) # Sort by Agreement Score column (index 2) by default @@ -113,7 +111,6 @@ def _qt_refresh(self): }) # Add unmatched units from analyzer2 - print("Remaining unmatched units in analyzer2:", len(all_units2)) for unit2_orig in all_units2: unit2_idx = list(self.controller.analyzer2.unit_ids).index(unit2_orig) unit2 = self.controller.unit_ids2[unit2_idx] @@ -138,36 +135,44 @@ def _qt_refresh(self): for i, row in enumerate(rows): # Unit 1 column with color if row['unit1'] != '': - item1 = QT.QTableWidgetItem(str(row['unit1'])) unit1 = np.array([row['unit1']]).astype(self.unit_dtype)[0] - color1 = self.controller.get_unit_color(unit1) - item1.setBackground(QT.QColor(int(color1[0] * 255), int(color1[1] * 255), int(color1[2] * 255))) - # Set text color to white or black depending on background brightness - brightness1 = 0.299 * color1[0] + 0.587 * color1[1] + 0.114 * color1[2] - text_color1 = QT.QColor(255, 255, 255) if brightness1 < 0.5 else QT.QColor(0, 0, 0) - item1.setForeground(text_color1) + n = row['num_spikes1'] + name = f'{unit1} n={n}' + color = self.get_unit_color(unit1) + pix = QT.QPixmap(16, 16) + pix.fill(color) + icon = QT.QIcon(pix) + item1 = QT.QTableWidgetItem(name) + item1.setData(QT.Qt.ItemDataRole.UserRole, unit1) + item1.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable) + item1.setIcon(icon) + item1.unit1 = unit1 else: item1 = QT.QTableWidgetItem('') + item1.unit1 = '' self.table.setItem(i, 0, item1) # Unit 2 column with color if row['unit2'] != '': - item2 = QT.QTableWidgetItem(str(row['unit2'])) unit2 = np.array([row['unit2']]).astype(self.unit_dtype)[0] - color2 = self.controller.get_unit_color(unit2) - item2.setBackground(QT.QColor(int(color2[0] * 255), int(color2[1] * 255), int(color2[2] * 255))) - # Set text color to white or black depending on background brightness - brightness2 = 0.299 * 
color2[0] + 0.587 * color2[1] + 0.114 * color2[2] - text_color2 = QT.QColor(255, 255, 255) if brightness2 < 0.5 else QT.QColor(0, 0, 0) - item2.setForeground(text_color2) + n = row['num_spikes2'] + name = f'{unit2} n={n}' + color = self.get_unit_color(unit2) + pix = QT.QPixmap(16, 16) + pix.fill(color) + icon = QT.QIcon(pix) + item2 = QT.QTableWidgetItem(name) + item2.setData(QT.Qt.ItemDataRole.UserRole, unit2) + item2.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable) + item2.setIcon(icon) + item2.unit2 = unit2 else: item2 = QT.QTableWidgetItem('') + item2.unit2 = '' self.table.setItem(i, 1, item2) # Other columns self.table.setItem(i, 2, QT.QTableWidgetItem(row['agreement_score'])) - self.table.setItem(i, 3, QT.QTableWidgetItem(str(row['num_spikes1']))) - self.table.setItem(i, 4, QT.QTableWidgetItem(str(row['num_spikes2']))) # Re-enable sorting after populating self.table.setSortingEnabled(True) @@ -186,19 +191,17 @@ def _qt_on_selection_changed(self): # Get unit values from table items unit1_item = self.table.item(row_idx, 0) unit2_item = self.table.item(row_idx, 1) - + unit1 = unit1_item.unit1 + unit2 = unit2_item.unit2 + # Collect units to make visible visible_units = [] - if unit1_item is not None and unit1_item.text() != '': - unit1 = np.array([unit1_item.text()]).astype(self.unit_dtype)[0] + if unit1 != '': visible_units.append(unit1) - - if unit2_item is not None and unit2_item.text() != '': - unit2 = np.array([unit2_item.text()]).astype(self.unit_dtype)[0] + if unit2 != '': visible_units.append(unit2) - print("Selected units:", visible_units) # Update visibility if visible_units: self.controller.set_visible_unit_ids(visible_units)