diff --git a/kilosort/gui/sorter.py b/kilosort/gui/sorter.py index c60140f0..c9a47de7 100644 --- a/kilosort/gui/sorter.py +++ b/kilosort/gui/sorter.py @@ -48,79 +48,92 @@ def run(self): results_dir.mkdir(parents=True) setup_logger(results_dir) - logger.info(f"Kilosort version {kilosort.__version__}") - logger.info(f"Sorting {self.data_path}") - logger.info('-'*40) - - tic0 = time.time() - - # TODO: make these options in GUI - do_CAR=True - invert_sign=False - - if not do_CAR: - logger.info("Skipping common average reference.") - - if probe['chanMap'].max() >= settings['n_chan_bin']: - raise ValueError( - f'Largest value of chanMap exceeds channel count of data, ' - 'make sure chanMap is 0-indexed.' - ) - - if settings['nt0min'] is None: - settings['nt0min'] = int(20 * settings['nt']/61) - data_dtype = settings['data_dtype'] - device = self.device - save_preprocessed_copy = settings['save_preprocessed_copy'] - - ops = initialize_ops(settings, probe, data_dtype, do_CAR, - invert_sign, device, save_preprocessed_copy) - # Remove some stuff that doesn't need to be printed twice, then pretty-print - # format for log file. - ops_copy = ops.copy() - _ = ops_copy.pop('settings') - _ = ops_copy.pop('probe') - print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False) - logger.debug(f"Initial ops:\n{print_ops}\n") - - # TODO: add support for file object through data conversion - # Set preprocessing and drift correction parameters - ops = compute_preprocessing(ops, self.device, tic0=tic0, - file_object=self.file_object) - np.random.seed(1) - torch.cuda.manual_seed_all(1) - torch.random.manual_seed(1) - ops, bfile, st0 = compute_drift_correction( - ops, self.device, tic0=tic0, progress_bar=self.progress_bar, - file_object=self.file_object - ) - - # Check scale of data for log file - b1 = bfile.padded_batch_to_torch(0).cpu().numpy() - logger.debug(f"First batch min, max: {b1.min(), b1.max()}") - - if save_preprocessed_copy: - save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile) - - # Will be None if nblocks = 0 (no drift correction) - if st0 is not None: - self.dshift = ops['dshift'] - self.st0 = st0 - self.plotDataReady.emit('drift') - - # Sort spikes and save results - st, tF, Wall0, clu0 = detect_spikes(ops, self.device, bfile, tic0=tic0, - progress_bar=self.progress_bar) - - self.Wall0 = Wall0 - self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy() - self.clu0 = clu0 - self.plotDataReady.emit('diagnostics') - - clu, Wall = cluster_spikes(st, tF, ops, self.device, bfile, tic0=tic0, - progress_bar=self.progress_bar) - ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ - save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0) + + try: + logger.info(f"Kilosort version {kilosort.__version__}") + logger.info(f"Sorting {self.data_path}") + logger.info('-'*40) + + tic0 = time.time() + + # TODO: make these options in GUI + do_CAR=True + invert_sign=False + + if not do_CAR: + logger.info("Skipping common average reference.") + + if probe['chanMap'].max() >= settings['n_chan_bin']: + raise ValueError( + f'Largest value of chanMap exceeds channel count of data, ' + 'make sure chanMap is 0-indexed.' + ) + + if settings['nt0min'] is None: + settings['nt0min'] = int(20 * settings['nt']/61) + data_dtype = settings['data_dtype'] + device = self.device + save_preprocessed_copy = settings['save_preprocessed_copy'] + + ops = initialize_ops(settings, probe, data_dtype, do_CAR, + invert_sign, device, save_preprocessed_copy) + # Remove some stuff that doesn't need to be printed twice, + # then pretty-print format for log file. + ops_copy = ops.copy() + _ = ops_copy.pop('settings') + _ = ops_copy.pop('probe') + print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False) + logger.debug(f"Initial ops:\n{print_ops}\n") + + # TODO: add support for file object through data conversion + # Set preprocessing and drift correction parameters + ops = compute_preprocessing(ops, self.device, tic0=tic0, + file_object=self.file_object) + np.random.seed(1) + torch.cuda.manual_seed_all(1) + torch.random.manual_seed(1) + ops, bfile, st0 = compute_drift_correction( + ops, self.device, tic0=tic0, progress_bar=self.progress_bar, + file_object=self.file_object + ) + + # Check scale of data for log file + b1 = bfile.padded_batch_to_torch(0).cpu().numpy() + logger.debug(f"First batch min, max: {b1.min(), b1.max()}") + + if save_preprocessed_copy: + save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile) + + # Will be None if nblocks = 0 (no drift correction) + if st0 is not None: + self.dshift = ops['dshift'] + self.st0 = st0 + self.plotDataReady.emit('drift') + + # Sort spikes and save results + st, tF, Wall0, clu0 = detect_spikes( + ops, self.device, bfile, tic0=tic0, + progress_bar=self.progress_bar + ) + + self.Wall0 = Wall0 + self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy() + self.clu0 = clu0 + self.plotDataReady.emit('diagnostics') + + clu, Wall = cluster_spikes( + st, tF, ops, self.device, bfile, tic0=tic0, + progress_bar=self.progress_bar + ) + ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ + save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0) + + except: + # This makes sure the full traceback is written to log file. + logger.exception('Encountered error in `run_kilosort`:') + # Annoyingly, this will print the error message twice for console + # but I haven't found a good way around that. + raise self.ops = ops self.st = st[kept_spikes] diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index c3345ea3..381e8631 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -112,74 +112,82 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, filename, data_dir, results_dir, probe = \ set_files(settings, filename, probe, probe_name, data_dir, results_dir) setup_logger(results_dir) - logger.info(f"Kilosort version {kilosort.__version__}") - logger.info(f"Sorting {filename}") - logger.info('-'*40) - if data_dtype is None: - logger.info( - "Interpreting binary file as default dtype='int16'. If data was " - "saved in a different format, specify `data_dtype`." + try: + logger.info(f"Kilosort version {kilosort.__version__}") + logger.info(f"Sorting {filename}") + logger.info('-'*40) + + if data_dtype is None: + logger.info( + "Interpreting binary file as default dtype='int16'. If data was " + "saved in a different format, specify `data_dtype`." + ) + data_dtype = 'int16' + + if not do_CAR: + logger.info("Skipping common average reference.") + + if device is None: + if torch.cuda.is_available(): + logger.info('Using GPU for PyTorch computations. ' + 'Specify `device` to change this.') + device = torch.device('cuda') + else: + logger.info('Using CPU for PyTorch computations. ' + 'Specify `device` to change this.') + device = torch.device('cpu') + + if probe['chanMap'].max() >= settings['n_chan_bin']: + raise ValueError( + f'Largest value of chanMap exceeds channel count of data, ' + 'make sure chanMap is 0-indexed.' ) - data_dtype = 'int16' - - if not do_CAR: - logger.info("Skipping common average reference.") - if device is None: - if torch.cuda.is_available(): - logger.info('Using GPU for PyTorch computations. ' - 'Specify `device` to change this.') - device = torch.device('cuda') - else: - logger.info('Using CPU for PyTorch computations. ' - 'Specify `device` to change this.') - device = torch.device('cpu') - - if probe['chanMap'].max() >= settings['n_chan_bin']: - raise ValueError( - f'Largest value of chanMap exceeds channel count of data, ' - 'make sure chanMap is 0-indexed.' - ) - - tic0 = time.time() - ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, - device, save_preprocessed_copy) - # Remove some stuff that doesn't need to be printed twice, then pretty-print - # format for log file. - ops_copy = ops.copy() - _ = ops_copy.pop('settings') - _ = ops_copy.pop('probe') - print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False) - logger.debug(f"Initial ops:\n{print_ops}\n") - - - # Set preprocessing and drift correction parameters - ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) - np.random.seed(1) - torch.cuda.manual_seed_all(1) - torch.random.manual_seed(1) - ops, bfile, st0 = compute_drift_correction( - ops, device, tic0=tic0, progress_bar=progress_bar, - file_object=file_object - ) - - # Check scale of data for log file - b1 = bfile.padded_batch_to_torch(0).cpu().numpy() - logger.debug(f"First batch min, max: {b1.min(), b1.max()}") - - if save_preprocessed_copy: - io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile) + tic0 = time.time() + ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, + device, save_preprocessed_copy) + # Remove some stuff that doesn't need to be printed twice, then pretty-print + # format for log file. + ops_copy = ops.copy() + _ = ops_copy.pop('settings') + _ = ops_copy.pop('probe') + print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False) + logger.debug(f"Initial ops:\n{print_ops}\n") + + + # Set preprocessing and drift correction parameters + ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object) + np.random.seed(1) + torch.cuda.manual_seed_all(1) + torch.random.manual_seed(1) + ops, bfile, st0 = compute_drift_correction( + ops, device, tic0=tic0, progress_bar=progress_bar, + file_object=file_object + ) - # Sort spikes and save results - st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, - progress_bar=progress_bar) - clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, - progress_bar=progress_bar) - ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ - save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, - save_extra_vars=save_extra_vars, - save_preprocessed_copy=save_preprocessed_copy) + # Check scale of data for log file + b1 = bfile.padded_batch_to_torch(0).cpu().numpy() + logger.debug(f"First batch min, max: {b1.min(), b1.max()}") + + if save_preprocessed_copy: + io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile) + + # Sort spikes and save results + st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, + progress_bar=progress_bar) + clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, + progress_bar=progress_bar) + ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \ + save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, + save_extra_vars=save_extra_vars, + save_preprocessed_copy=save_preprocessed_copy) + except: + # This makes sure the full traceback is written to log file. + logger.exception('Encountered error in `run_kilosort`:') + # Annoyingly, this will print the error message twice for console, but + # I haven't found a good way around that. + raise return ops, st, clu, tF, Wall, similar_templates, \ is_ref, est_contam_rate, kept_spikes @@ -435,6 +443,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None, Wrapped file object for handling data. """ + tic = time.time() logger.info(' ') logger.info('Computing drift correction.')