diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9543cbf734..627e6c2af9 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -83,8 +83,8 @@ def __init__( templates = ext_templates.get_templates(unit_ids=unit_ids) bin_min = np.min(templates) * 1.3 bin_max = np.max(templates) * 1.3 - bin_size = (bin_max - bin_min) / 100 - bins = np.arange(bin_min, bin_max, bin_size) + num_bins = 100 + bins = np.linspace(bin_min, bin_max, num_bins + 1) # 2d histograms if same_axis: @@ -121,14 +121,9 @@ def __init__( wfs = wfs_ # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) - hist2d = np.zeros((wfs_flat.shape[1], bins.size)) - indexes0 = np.arange(wfs_flat.shape[1]) - - wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32") - wf_bined = wf_bined.clip(0, bins.size - 1) - for d in wf_bined: - hist2d[indexes0, d] += 1 + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x times*num_channels + hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T] + hist2d = np.stack(hists_per_timepoint) if same_axis: if all_hist2d is None: @@ -169,7 +164,6 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot)