Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spikeinterface_gui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@

from .version import version as __version__

from .main import run_mainwindow, run_launcher
from .main import run_mainwindow, run_launcher, run_mainwindow_comparison

212 changes: 212 additions & 0 deletions spikeinterface_gui/compareunitlistview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import numpy as np

import pyqtgraph as pg

from .view_base import ViewBase


class CompareUnitListView(ViewBase):
"""
View for displaying unit comparison between two analyzers.
Shows matched units, their agreement scores, and spike counts.
"""
_supported_backend = ['qt']
_depend_on = ['comparison']
_gui_help_txt = "Display comparison table between two sorting outputs"
_settings = [
{"name": "matching_mode", "type": "list", "value": "hungarian", "options": ["hungarian", "best_match"]},
]

def __init__(self, controller=None, parent=None, backend="qt"):
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
self.unit_dtype = self.controller.unit_ids.dtype


def _qt_make_layout(self):
from .myqt import QT

self.layout = QT.QVBoxLayout()

# Create table widget
self.table = QT.QTableWidget()
self.layout.addWidget(self.table)

# Setup table
self.table.setSelectionBehavior(QT.QAbstractItemView.SelectRows)
self.table.setSelectionMode(QT.QAbstractItemView.SingleSelection)
self.table.itemSelectionChanged.connect(self._qt_on_selection_changed)

# Setup table structure
self.table.setColumnCount(3)
self.table.setHorizontalHeaderLabels([
f'Unit ({self.controller.analyzer1_name})',
f'Unit ({self.controller.analyzer2_name})',
'Agreement Score',
])
self.table.setSortingEnabled(True)
# Sort by Agreement Score column (index 2) by default
self.table.sortItems(2, QT.Qt.DescendingOrder)
self.table.setSelectionMode(QT.QAbstractItemView.SingleSelection)

def _qt_refresh(self):
"""Populate/refresh the comparison table with data"""
from .myqt import QT

comp = self.controller.comp

# Get comparison data
if self.settings['matching_mode'] == 'hungarian':
matching_12 = comp.hungarian_match_12
else:
matching_12 = comp.best_match_12
matching_12 = comp.hungarian_match_12
agreement_scores = comp.agreement_scores

# Get all units from both analyzers
all_units2 = set(self.controller.analyzer2.unit_ids)

# Build rows: matched pairs + unmatched units
rows = []

# Add matched units
for unit1_orig in matching_12.index:
unit2_orig = matching_12[unit1_orig]
if unit2_orig != -1:
# Get combined unit_ids
unit1_idx = list(self.controller.analyzer1.unit_ids).index(unit1_orig)
unit2_idx = list(self.controller.analyzer2.unit_ids).index(unit2_orig)
unit1 = self.controller.unit_ids1[unit1_idx]
unit2 = self.controller.unit_ids2[unit2_idx]

score = agreement_scores.at[unit1_orig, unit2_orig]
num_spikes1 = self.controller.analyzer1.sorting.get_unit_spike_train(unit1_orig).size
num_spikes2 = self.controller.analyzer2.sorting.get_unit_spike_train(unit2_orig).size

rows.append(
{
'unit1': str(unit1),
'unit2': str(unit2),
'unit1_orig': unit1_orig,
'unit2_orig': unit2_orig,
'agreement_score': f"{score:.3f}",
'num_spikes1': num_spikes1,
'num_spikes2': num_spikes2
}
)
all_units2.discard(unit2_orig)
else:
# Unmatched unit from analyzer1
unit1_idx = list(self.controller.analyzer1.unit_ids).index(unit1_orig)
unit1 = self.controller.unit_ids1[unit1_idx]
num_spikes1 = self.controller.analyzer1.sorting.get_unit_spike_train(unit1_orig).size

rows.append({
'unit1': str(unit1),
'unit2': '',
'unit1_orig': unit1_orig,
'unit2_orig': '',
'agreement_score': '0',
'num_spikes1': num_spikes1,
'num_spikes2': 0
})

# Add unmatched units from analyzer2
for unit2_orig in all_units2:
unit2_idx = list(self.controller.analyzer2.unit_ids).index(unit2_orig)
unit2 = self.controller.unit_ids2[unit2_idx]
num_spikes2 = self.controller.analyzer2.sorting.get_unit_spike_train(unit2_orig).size

rows.append({
'unit1': '',
'unit2': str(unit2),
'unit1_orig': '',
'unit2_orig': unit2_orig,
'agreement_score': '',
'num_spikes1': 0,
'num_spikes2': num_spikes2
})

# Populate rows
print(len(rows), "rows to display in comparison table")
# Disable sorting while populating
self.table.setSortingEnabled(False)
self.table.setRowCount(len(rows))

for i, row in enumerate(rows):
# Unit 1 column with color
if row['unit1'] != '':
unit1 = np.array([row['unit1']]).astype(self.unit_dtype)[0]
n = row['num_spikes1']
name = f'{unit1} n={n}'
color = self.get_unit_color(unit1)
pix = QT.QPixmap(16, 16)
pix.fill(color)
icon = QT.QIcon(pix)
item1 = QT.QTableWidgetItem(name)
item1.setData(QT.Qt.ItemDataRole.UserRole, unit1)
item1.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable)
item1.setIcon(icon)
item1.unit1 = unit1
else:
item1 = QT.QTableWidgetItem('')
item1.unit1 = ''
self.table.setItem(i, 0, item1)

# Unit 2 column with color
if row['unit2'] != '':
unit2 = np.array([row['unit2']]).astype(self.unit_dtype)[0]
n = row['num_spikes2']
name = f'{unit2} n={n}'
color = self.get_unit_color(unit2)
pix = QT.QPixmap(16, 16)
pix.fill(color)
icon = QT.QIcon(pix)
item2 = QT.QTableWidgetItem(name)
item2.setData(QT.Qt.ItemDataRole.UserRole, unit2)
item2.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable)
item2.setIcon(icon)
item2.unit2 = unit2
else:
item2 = QT.QTableWidgetItem('')
item2.unit2 = ''
self.table.setItem(i, 1, item2)

# Other columns
self.table.setItem(i, 2, QT.QTableWidgetItem(row['agreement_score']))

# Re-enable sorting after populating
self.table.setSortingEnabled(True)
# Resize columns
self.table.resizeColumnsToContents()


def _qt_on_selection_changed(self):
"""Handle row selection and update unit visibility"""
selected_rows = []
for item in self.table.selectedItems():
if item.column() != 1: continue
selected_rows.append(item.row())

row_idx = selected_rows[0]
# Get unit values from table items
unit1_item = self.table.item(row_idx, 0)
unit2_item = self.table.item(row_idx, 1)
unit1 = unit1_item.unit1
unit2 = unit2_item.unit2

# Collect units to make visible
visible_units = []

if unit1 != '':
visible_units.append(unit1)
if unit2 != '':
visible_units.append(unit2)

# Update visibility
if visible_units:
self.controller.set_visible_unit_ids(visible_units)
self.notify_unit_visibility_changed()

def on_unit_visibility_changed(self):
"""Handle external unit visibility changes - could highlight selected row"""
pass
24 changes: 14 additions & 10 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel
import spikeinterface.postprocessing
import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
Expand All @@ -30,6 +28,7 @@
)

from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties



class Controller():
Expand Down Expand Up @@ -269,18 +268,21 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
seg_limits = np.searchsorted(self.spikes["segment_index"], np.arange(num_seg + 1))
self.segment_slices = {segment_index: slice(seg_limits[segment_index], seg_limits[segment_index + 1]) for segment_index in range(num_seg)}

spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
spike_vector2 = []
for segment_index in range(num_seg):
seg_slice = self.segment_slices[segment_index]
spike_vector2.append(self.spikes[seg_slice])
self.final_spike_samples = [segment_spike_vector[-1][0] for segment_spike_vector in spike_vector2]
# this is dict of list because per segment spike_indices[segment_index][unit_id]
spike_indices_abs = spike_vector_to_indices(spike_vector2, unit_ids, absolute_index=True)
spike_indices = spike_vector_to_indices(spike_vector2, unit_ids)
spike_indices_abs = spike_vector_to_indices(spike_vector2, self.unit_ids, absolute_index=True)
spike_indices = spike_vector_to_indices(spike_vector2, self.unit_ids)
# this is flatten
spike_per_seg = [s.size for s in spike_vector2]
# dict[unit_id] -> all indices for this unit across segments
self._spike_index_by_units = {}
# dict[segment_index][unit_id] -> all indices for this unit for one segment
self._spike_index_by_segment_and_units = spike_indices_abs
for unit_id in unit_ids:
for unit_id in self.unit_ids:
inds = []
for seg_ind in range(num_seg):
inds.append(spike_indices[seg_ind][unit_id] + int(np.sum(spike_per_seg[:seg_ind])))
Expand Down Expand Up @@ -684,11 +686,11 @@ def get_waveform_sweep(self):
def get_waveforms_range(self):
return np.nanmin(self.templates_average), np.nanmax(self.templates_average)

def get_waveforms(self, unit_id):
wfs = self.waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False)
if self.analyzer.sparsity is None:
def get_waveforms(self, unit_id, force_dense=False):
wfs = self.waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=force_dense)
if self.analyzer.sparsity is None or force_dense:
# dense waveforms
chan_inds = np.arange(self.analyzer.recording.get_num_channels(), dtype='int64')
chan_inds = np.arange(self.analyzer.get_num_channels(), dtype='int64')
else:
# sparse waveforms
chan_inds = self.analyzer.sparsity.unit_id_to_channel_indices[unit_id]
Expand Down Expand Up @@ -1026,3 +1028,5 @@ def remove_category_from_unit(self, unit_id, category):
elif lbl.get('labels') is not None and category in lbl.get('labels'):
lbl['labels'].pop(category)
self.curation_data["manual_labels"][ix] = lbl


Loading