diff --git a/atlaselectrophysiology/load_data.py b/atlaselectrophysiology/load_data.py index 05e7e5d..c8906c1 100644 --- a/atlaselectrophysiology/load_data.py +++ b/atlaselectrophysiology/load_data.py @@ -5,6 +5,8 @@ from neuropixel import trace_header import ibllib.atlas as atlas from ibllib.qc.alignment_qc import AlignmentQC +from iblutil.numerical import ismember +from iblutil.util import Bunch from one.api import ONE from one.remote import aws from pathlib import Path @@ -259,16 +261,29 @@ def get_data(self): try: data['spikes'] = self.one.load_object(self.eid, 'spikes', collection=self.probe_collection, attribute=['depths', 'amps', 'times', 'clusters']) - data['spikes']['exists'] = True data['clusters'] = self.one.load_object(self.eid, 'clusters', collection=self.probe_collection, attribute=['metrics', 'peakToTrough', 'waveforms', 'channels']) + + # Remove low firing rate clusters + min_firing_rate = 50. / 3600. + clu_idx = data['clusters'].metrics.firing_rate > min_firing_rate + data['clusters'] = Bunch({k: v[clu_idx] for k, v in data['clusters'].items()}) + spike_idx, ib = ismember(data['spikes'].clusters, data['clusters'].metrics.index) + data['clusters'].metrics.reset_index(drop=True, inplace=True) + data['spikes'] = Bunch({k: v[spike_idx] for k, v in data['spikes'].items()}) + data['spikes'].clusters = data['clusters'].metrics.index[ib].astype(np.int32) + + data['spikes']['exists'] = True data['clusters']['exists'] = True data['channels'] = self.one.load_object(self.eid, 'channels', collection=self.probe_collection, attribute=['rawInd', 'localCoordinates']) data['channels']['exists'] = True + # Set low firing rate clusters to bad + + except alf.exceptions.ALFObjectNotFound: logger.error(f'Could not load spike sorting for probe insertion {self.probe_id}, GUI' f' will not work') diff --git a/atlaselectrophysiology/plot_data.py b/atlaselectrophysiology/plot_data.py index 97043dc..e3dd517 100644 --- a/atlaselectrophysiology/plot_data.py +++ b/atlaselectrophysiology/plot_data.py @@ -1,6 +1,6 @@ from matplotlib import cm import numpy as np -from brainbox.processing import bincount2D +from iblutil.numerical import bincount2D from brainbox.io.spikeglx import Streamer from brainbox.population.decode import xcorr from brainbox.task import passive @@ -596,7 +596,7 @@ def get_autocorr(self, clust_idx): autocorr = xcorr(self.data['spikes']['times'][idx], self.data['spikes']['clusters'][idx], AUTOCORR_BIN_SIZE, AUTOCORR_WIN_SIZE) - return autocorr[0, 0, :], self.clust_id[clust_idx] + return autocorr[0, 0, :], self.data['clusters'].metrics.cluster_id[self.clust_id[clust_idx]] def get_template_wf(self, clust_idx): template_wf = (self.data['clusters']['waveforms'][self.clust_id[clust_idx], :, 0])