diff --git a/ecogvis/ecogvis.py b/ecogvis/ecogvis.py index 84f9b3b..00bf5fc 100755 --- a/ecogvis/ecogvis.py +++ b/ecogvis/ecogvis.py @@ -33,6 +33,7 @@ from ecogvis.functions.save_to_nwb import SaveToNWBDialog from ecogvis.functions.nwb_copy_file import nwb_copy_file + annotationAdd_ = False annotationDel_ = False annotationColor_ = 'red' @@ -178,9 +179,6 @@ def init_gui(self): action_del_badchannel = QAction('Del Bad Channel', self) channels_tools_menu.addAction(action_del_badchannel) action_del_badchannel.triggered.connect(self.del_badchannel) - action_save_badchannel = QAction('Save Bad Channels', self) - channels_tools_menu.addAction(action_save_badchannel) - action_save_badchannel.triggered.connect(self.save_badchannel) action_spectral_decomposition = QAction('Spectral Decomposition', self) toolsMenu.addAction(action_spectral_decomposition) action_spectral_decomposition.triggered.connect(self.spectral_decomposition) @@ -251,8 +249,6 @@ def init_gui(self): self.push3_1.clicked.connect(self.add_badchannel) self.push3_2 = QPushButton('Del') self.push3_2.clicked.connect(self.del_badchannel) - self.push3_3 = QPushButton('Save') - self.push3_3.clicked.connect(self.save_badchannel) # Get channel by brain region self.push4_0 = QPushButton('Select regions') @@ -284,7 +280,6 @@ def init_gui(self): grid1.addWidget(qlabelBadChannels, 4, 0, 1, 6) grid1.addWidget(self.push3_1, 5, 0, 1, 2) grid1.addWidget(self.push3_2, 5, 2, 1, 2) - grid1.addWidget(self.push3_3, 5, 4, 1, 2) grid1.addWidget(self.push4_0, 6, 0, 1, 3) grid1.addWidget(self.push5_0, 6, 3, 1, 3) grid1.addWidget(self.push6_0, 7, 0, 1, 3) @@ -576,10 +571,6 @@ def del_badchannel(self): except Exception as ex: print(str(ex)) - def save_badchannel(self): - """Saves bad channels to file.""" - self.model.BadChannelSave() - def spectral_decomposition(self): """Opens Spectral decomposition dialog.""" w = SpectralChoiceDialog(self) @@ -755,7 +746,7 @@ def ChannelSelect(self): self.reset_buttons() # Dialog to choose channels from specific brain regions w = SelectChannelsDialog(self.model.all_regions, self.model.regions_mask) - all_locs = self.model.nwb.electrodes['location'][:] + all_locs = self.model.nwb.electrodes['location'][self.model.electrical_series_channel_ids] self.model.channels_mask = np.zeros(len(all_locs)) for loc in w.choices: self.model.channels_mask += all_locs == np.array(loc) diff --git a/ecogvis/functions/nwb_copy_file.py b/ecogvis/functions/nwb_copy_file.py index 0d0aa77..6b465b9 100644 --- a/ecogvis/functions/nwb_copy_file.py +++ b/ecogvis/functions/nwb_copy_file.py @@ -266,10 +266,10 @@ def copy_obj(obj_old, nwb_old, nwb_new): # ElectricalSeries -------------------------------------------------------- if type(obj_old) is ElectricalSeries: - nChannels = len(obj_old.electrodes.table['x'].data) - elecs_region = nwb_new.electrodes.create_region( + region = np.array(obj_old.electrodes.table.id[:])[obj_old.electrodes.data[:]].tolist() + elecs_region = nwb_new.create_electrode_table_region( name='electrodes', - region=np.arange(nChannels).tolist(), + region=region, description='' ) els = ElectricalSeries( @@ -297,7 +297,6 @@ def copy_obj(obj_old, nwb_old, nwb_new): 'Expected precisely one electrical series, got %i!' % len(obj_old.electrical_series)) els = list(obj_old.electrical_series.values())[0] - nChannels = els.data.shape[1] #### # first check for a table among the new file's data_interfaces @@ -308,9 +307,10 @@ def copy_obj(obj_old, nwb_old, nwb_new): LFP_dynamic_table = nwb_new.electrodes #### + region = np.array(els.electrodes.table.id[:])[els.electrodes.data[:]].tolist() elecs_region = LFP_dynamic_table.create_region( name='electrodes', - region=[i for i in range(nChannels)], + region=region, description=els.electrodes.description ) diff --git a/ecogvis/functions/subFunctions.py b/ecogvis/functions/subFunctions.py index 5f31940..3c58f86 100755 --- a/ecogvis/functions/subFunctions.py +++ b/ecogvis/functions/subFunctions.py @@ -73,15 +73,17 @@ def __init__(self, par, htk_config=None): self.tbin_signal = 1 / self.fs_signal # time bin duration [seconds] self.nBins = self.source.data.shape[0] # total number of bins self.min_window_bins = 10 # minimum number of bins to plot - self.nChTotal = self.source.data.shape[1] # total number of channels - self.allChannels = np.arange(0, self.nChTotal) # array with all channels + # all electricalseries channels ids + self.all_channels_ids = self.source.electrodes.table.id[:] + self.electrical_series_channel_ids = np.array(self.all_channels_ids)[self.source.electrodes.data[:]].tolist() + self.n_channels_total = len(self.electrical_series_channel_ids) # total number of channels # Get Brain regions present in current file - self.all_regions = list(set(list(self.nwb.electrodes['location'][:]))) + self.all_regions = list(set(list(self.nwb.electrodes['location'][self.electrical_series_channel_ids]))) self.all_regions.sort() self.regions_mask = [True] * len(self.all_regions) - self.channels_mask = np.ones(len(self.nwb.electrodes['location'][:])) + self.channels_mask = np.ones(len(self.regions_mask)) self.channels_mask_ind = np.where(self.channels_mask)[0] self.h = [] @@ -93,7 +95,7 @@ def __init__(self, par, htk_config=None): # Channels to show self.firstCh = int(self.parent.qline1.text()) - self.lastCh = min(self.nChTotal, int(self.parent.qline0.text())) + self.lastCh = min(self.n_channels_total, int(self.parent.qline0.text())) self.parent.qline0.setText(str(self.lastCh)) self.nChToShow = self.lastCh - self.firstCh + 1 self.selectedChannels = np.arange(self.firstCh - 1, self.lastCh) @@ -102,9 +104,10 @@ def __init__(self, par, htk_config=None): # List of bad channels if 'bad' in self.nwb.electrodes: - self.badChannels = np.where(self.nwb.electrodes['bad'][:])[0].tolist() + aux_mask = self.nwb.electrodes[self.electrical_series_channel_ids]['bad'] + self.bad_channels_ids = list(self.nwb.electrodes[self.electrical_series_channel_ids][aux_mask].index) else: - self.badChannels = [] + self.bad_channels_ids = [] # Load invalid intervals from NWB file self.allIntervals = [] @@ -249,7 +252,8 @@ def TimeSeries_plotter(self): # Iterate over chosen channels, plot one at a time nrows, ncols = np.shape(plotData) for i in range(nrows): - if self.selectedChannels[i] in self.badChannels: + elec_index = self.source.electrodes[int(self.selectedChannels[i])].index[0] + if elec_index in self.bad_channels_ids: plt2.plot(timebaseGuiUnits, plotData[i], pen=pg.mkPen((220, 0, 0), width=1.2)) else: c = pg.mkPen((0, 120, 0), width=1.2) @@ -258,7 +262,7 @@ def TimeSeries_plotter(self): plt2.plot(timebaseGuiUnits, plotData[i], pen=c) plt2.setLabel('bottom', 'Time', units='sec') plt2.setLabel('left', 'Channel #') - labels = [str(ch + 1) for ch in self.selectedChannels] + labels = [str(self.electrical_series_channel_ids[ch]) for ch in self.selectedChannels] ticks = list(zip(self.scaleVec, labels)) plt2.getAxis('left').setTicks([ticks]) plt2.setXRange(timebaseGuiUnits[0], timebaseGuiUnits[-1], padding=0.003) @@ -414,11 +418,11 @@ def time_scroll(self, scroll=0): def channel_Scroll_Up(self, opt='unit'): """Updates the channels to be plotted. Buttons: ^, ^^ """ # Test upper limit - if self.lastCh < self.nChTotal: + if self.lastCh < self.n_channels_total: if opt == 'unit': step = 1 elif opt == 'page': - step = np.minimum(self.nChToShow, self.nChTotal - self.lastCh) + step = np.minimum(self.nChToShow, self.n_channels_total - self.lastCh) # Add +1 to first and last channels self.firstCh += step self.lastCh += step @@ -454,15 +458,15 @@ def nChannels_Displayed(self): if self.firstCh < 1: self.firstCh = 1 - if self.firstCh > self.nChTotal: - self.firstCh = self.nChTotal - self.lastCh = self.nChTotal + if self.firstCh > self.n_channels_total: + self.firstCh = self.n_channels_total + self.lastCh = self.n_channels_total if self.lastCh < 1: self.lastCh = self.firstCh - if self.lastCh > self.nChTotal: - self.lastCh = self.nChTotal + if self.lastCh > self.n_channels_total: + self.lastCh = self.n_channels_total if self.lastCh - self.firstCh < 1: self.lastCh = self.firstCh @@ -733,10 +737,11 @@ def BadChannelAdd(self, ch_list): ch_list : list of integers List of indices of channels to be marked as 'bad'. """ + # Update list of bad electrodes ids for ch in ch_list: - if ch not in self.badChannels: - self.badChannels.append(ch) - self.refreshScreen() + if ch not in self.bad_channels_ids: + self.bad_channels_ids.append(ch) + self.update_bad_channels() def BadChannelDel(self, ch_list): """ @@ -747,28 +752,32 @@ def BadChannelDel(self, ch_list): ch_list : list of integers List of indices of channels to be un-marked as 'bad'. """ + # Update list of bad electrodes ids for ch in ch_list: - if ch in self.badChannels: - self.badChannels.remove(ch) - self.refreshScreen() + if ch in self.bad_channels_ids: + self.bad_channels_ids.remove(ch) + self.update_bad_channels() + + def update_bad_channels(self): + """Updates list of bad channels after add or del""" + # List of electrodes IDs + elecs_ids = list(self.nwb.electrodes.id[:]) + is_bad_list = [False] * len(elecs_ids) + for i, id in enumerate(elecs_ids): + if id in self.bad_channels_ids: + is_bad_list[i] = True + + if 'bad' not in self.nwb.electrodes: + self.nwb.add_electrode_column( + name='bad', + description='electrode identified as too noisy', + data=is_bad_list, + ) + else: + self.nwb.electrodes['bad'].data[:] = is_bad_list - def BadChannelSave(self): - """Saves list of bad channels in current NWB file.""" - buttonReply = QMessageBox.question(None, ' ', 'Save Bad Channels on current NWB file?', - QMessageBox.No | QMessageBox.Yes) - if buttonReply == QMessageBox.Yes: - # Modify current list of bad channels - aux = [False] * self.nChTotal - for ind in self.badChannels: - aux[ind] = True - if 'bad' not in self.nwb.electrodes: - self.nwb.add_electrode_column( - name='bad', - description='electrode identified as too noisy', - data=aux, - ) - else: - self.nwb.electrodes['bad'].data[:] = aux + # Refresh screen + self.refreshScreen() def DrawMarkTime(self, position): """Marks temporary reference line when adding a new interval."""