From a84e54caa35ad679765e056fda0a719854719c73 Mon Sep 17 00:00:00 2001 From: owinter Date: Wed, 21 Aug 2024 14:36:35 +0100 Subject: [PATCH] add features to the raster viewer --- atlasview/atlasview.py | 8 ++--- viewspikes/gui.py | 69 +++++++++++++++++++++++++++++++----------- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/atlasview/atlasview.py b/atlasview/atlasview.py index 79d402f7..32bd7ae9 100644 --- a/atlasview/atlasview.py +++ b/atlasview/atlasview.py @@ -261,10 +261,10 @@ class ControllerTopView(PgImageController): """ TopView ControllerTopView """ - def __init__(self, qmain: TopView, res: int = 25, volume='image', **kwargs): + def __init__(self, qmain: TopView, res: int = 25, volume='image', atlas=None, **kwargs): super(ControllerTopView, self).__init__(qmain) self.volume = volume - self.atlas = AllenAtlas(res) + self.atlas = AllenAtlas(res) if atlas is None else atlas self.fig_top = self.qwidget = qmain # Setup Coronal slice: width: ml, height: dv, depth: ap self.fig_coronal = SliceView(qmain, waxis=0, haxis=2, daxis=1) @@ -362,10 +362,10 @@ class ImageLayer: slice_kwargs: dict = field(default_factory=lambda: {'volume': 'image', 'mode': 'clip'}) -def view(res=25, title=None, brainmap='Allen'): +def view(res=25, title=None, atlas=None): """ application entry point """ qt.create_app() - av = TopView._get_or_create(title=title, res=res, brainmap=brainmap) + av = TopView._get_or_create(title=title, res=res, atlas=atlas) av.show() return av diff --git a/viewspikes/gui.py b/viewspikes/gui.py index 32313bf6..fa5f78d7 100644 --- a/viewspikes/gui.py +++ b/viewspikes/gui.py @@ -7,7 +7,7 @@ from iblutil.numerical import bincount2D from viewephys.gui import viewephys -import neurodsp +import ibldsp from brainbox.io.one import EphysSessionLoader, SpikeSortingLoader from iblatlas.atlas import BrainRegions @@ -27,19 +27,35 @@ (0.7372549019607844, 0.7411764705882353, 0.13333333333333333), (0.09019607843137255, 0.7450980392156863, 0.8117647058823529)] -YMAX = 4000 +YMIN, YMAX = (-1, 4000) + + +def get_trial_events_to_display(trials): + errors = trials['feedback_times'][trials['feedbackType'] == -1].values + errors = np.sort(np.r_[errors, errors + .5]) + gocue = trials['goCue_times'].values + gocue = np.sort(np.r_[gocue, gocue + .11]) + trial_events = dict( + goCue_times=gocue, + error_times=errors, + reward_times=trials['feedback_times'][trials['feedbackType'] == 1].values) + return trial_events def view_raster(pid, one, stream=True): + from qt import create_app + app = create_app() ssl = SpikeSortingLoader(one=one, pid=pid) sl = EphysSessionLoader(one=one, eid=ssl.eid) sl.load_trials() spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples']) - + clusters = ssl.merge_clusters(spikes, clusters, channels) return RasterView(ssl, spikes, clusters, channels, trials=sl.trials, stream=stream) class RasterView(QtWidgets.QMainWindow): + plotItem_raster: pg.PlotWidget = None + def __init__(self, ssl, spikes, clusters, channels=None, trials=None, stream=True, *args, **kwargs): self.ssl = ssl self.spikes = spikes @@ -81,18 +97,14 @@ def __init__(self, ssl, spikes, clusters, channels=None, trials=None, stream=Tru self.plotItem_raster.addItem(self.line_eqc) ################################################## plot trials if self.trials is not None: - trial_times = dict( - goCue_times=trials['goCue_times'].values, - error_times=trials['feedback_times'][trials['feedbackType'] == -1].values, - reward_times=trials['feedback_times'][trials['feedbackType'] == 1].values) + trial_times = get_trial_events_to_display(trials) self.trial_lines = {} for i, k in enumerate(trial_times): self.trial_lines[k] = pg.PlotCurveItem() self.plotItem_raster.addItem(self.trial_lines[k]) x = np.tile(trial_times[k][:, np.newaxis], (1, 2)).flatten() - y = np.tile(np.array([0, 1, 1, 0]), int(trial_times[k].shape[0] / 2 + 1))[ - :trial_times[k].shape[0] * 2] * YMAX - self.trial_lines[k].setData(x=x.flatten(), y=y.flatten(), pen=pg.mkPen(np.array(SNS_PALETTE[i]) * 256)) + y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(trial_times[k].shape[0] / 2 + 1))[:trial_times[k].shape[0] * 2] + self.trial_lines[k].setData(x=x.flatten(), y=y.flatten(), pen=pg.mkPen(np.array(SNS_PALETTE[i]) * 255, width=2)) self.show() def mouseClick(self, event): @@ -101,7 +113,7 @@ def mouseClick(self, event): return qxy = self.imageItem_raster.mapFromScene(event.scenePos()) x = qxy.x() - self.show_ephys(t0=self.rtimes[int(x - .5)]) + self.show_ephys(t0=self.rtimes[int(x - 1)]) ymax = np.max(self.depths) + 50 self.line_eqc.setData(x=x + np.array([-.5, -.5, .5, .5]), y=np.array([0, ymax, ymax, 0]), @@ -124,8 +136,13 @@ def keyPressEvent(self, e): m == QtCore.Qt.ControlModifier and k == QtCore.Qt.Key_Z): self.imageItem_raster.setLevels([0, self.imageItem_raster.levels[1] * 1.4]) - def show_ephys(self, t0, tlen=1): - + def show_ephys(self, t0, tlen=1.8): + """ + :param t0: behaviour time in seconds at which to start the view + :param tlen: + :return: + """ + print(t0) s0 = int(self.ssl.samples2times(t0, direction='reverse')) s1 = s0 + int(self.sr.fs * tlen) raw = self.sr[s0:s1, : - self.sr.nsync].T @@ -134,7 +151,7 @@ def show_ephys(self, t0, tlen=1): sos = scipy.signal.butter(**butter_kwargs, output='sos') butt = scipy.signal.sosfiltfilt(sos, raw) - destripe = neurodsp.voltage.destripe(raw, fs=self.sr.fs) + destripe = ibldsp.voltage.destripe(raw, fs=self.sr.fs) self.eqc_raw = viewephys(butt, self.sr.fs, channels=self.channels, br=regions, title='butt', t0=t0, t_scalar=1) self.eqc_des = viewephys(destripe, self.sr.fs, channels=self.channels, br=regions, title='destripe', t0=t0, t_scalar=1) @@ -146,6 +163,24 @@ def show_ephys(self, t0, tlen=1): # we slice the spikes using the samples according to ephys time, but display in session times slice_spikes = slice(np.searchsorted(self.spikes['samples'], s0), np.searchsorted(self.spikes['samples'], s1)) t = self.spikes['times'][slice_spikes] - c = self.clusters.channels[self.spikes.clusters[slice_spikes]] - self.eqc_raw.ctrl.add_scatter(t, c) - self.eqc_des.ctrl.add_scatter(t, c) + ic = self.spikes.clusters[slice_spikes] + + iok = self.clusters['label'][ic] == 1 + + for eqc in [self.eqc_des, self.eqc_raw]: + eqc.ctrl.add_scatter(t[~iok], self.clusters.channels[ic[~iok]], (255, 0, 0, 100), label='bad units') + eqc.ctrl.add_scatter(t[iok], self.clusters.channels[ic[iok]], rgb=(0, 255, 0, 100), label='good units') + + if self.trials is not None: + trial_events = get_trial_events_to_display(self.trials) + for i, k in enumerate(trial_events): + ie = np.logical_and(trial_events[k] >= t0, trial_events[k] <= (t0 + tlen)) + if np.sum(ie) == 0: + continue + te = trial_events[k][ie] + x = np.tile(te[:, np.newaxis], (1, 2)).flatten() + y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(te.shape[0] / 2 + 1))[:te.shape[0] * 2] + for eqc in [self.eqc_des, self.eqc_raw]: + eqc.ctrl.add_curve(x, y, rgb=(np.array(SNS_PALETTE[i]) * 255).astype(int), label=k) + print(te) +