diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py
index 78cc9dc4..374d9098 100644
--- a/kilosort/clustering_qr.py
+++ b/kilosort/clustering_qr.py
@@ -1,3 +1,5 @@
+import gc
+
 import numpy as np
 import torch
 from torch import sparse_coo_tensor as coo
@@ -301,7 +303,8 @@ def y_centers(ops):
     return centers
 
 
-def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None):
+def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
+        progress_bar=None, clear_cache=False):
 
     if mode == 'template':
         xy, iC = xy_templates(ops)
@@ -362,11 +365,16 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b
 
             # find new clusters
             iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
-                                                      seed=5, device=device)
+                                                      seed=5, device=device)
+            if clear_cache:
+                gc.collect()
+                torch.cuda.empty_cache()
 
             xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)
 
-            xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)
+            xtree, tstat = swarmsplitter.split(
+                Xd.numpy(), xtree, tstat, iclust, my_clus, meta=st0
+            )
 
             iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)
 
diff --git a/kilosort/datashift.py b/kilosort/datashift.py
index 52d4549b..30b9da79 100644
--- a/kilosort/datashift.py
+++ b/kilosort/datashift.py
@@ -183,7 +183,8 @@ def kernel2D(x, y, sig = 1):
     Kn = np.exp(-ds / (2*sig**2))
     return Kn
 
-def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
+def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
+        clear_cache=False):
     """ this step computes a drift correction model
     it returns vertical correction amplitudes for each batch, and for multiple blocks in a batch if nblocks > 1.
     """
@@ -194,7 +195,10 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
         return ops, None
 
     # the first step is to extract all spikes using the universal templates
-    st, _, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
+    st, _, ops = spikedetect.run(
+        ops, bfile, device=device, progress_bar=progress_bar,
+        clear_cache=clear_cache
+    )
 
     # spikes are binned by amplitude and y-position to construct a "fingerprint" for each batch
     F, ysamp = bin_spikes(ops, st)
diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py
index 28d1b727..a136807e 100644
--- a/kilosort/run_kilosort.py
+++ b/kilosort/run_kilosort.py
@@ -26,7 +26,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
                  data_dir=None, file_object=None, results_dir=None,
                  data_dtype=None, do_CAR=True, invert_sign=False, device=None,
-                 progress_bar=None, save_extra_vars=False,
+                 progress_bar=None, save_extra_vars=False, clear_cache=False,
                  save_preprocessed_copy=False, bad_channels=None):
     """Run full spike sorting pipeline on specified data.
 
     Parameters
@@ -82,6 +82,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
         not need to specify this.
     save_extra_vars : bool; default=False.
         If True, save tF and Wall to disk after sorting.
+    clear_cache : bool; default=False.
+        If True, force PyTorch to free up memory reserved for its cache in
+        between memory-intensive operations.
+        Note that setting `clear_cache=True` is NOT recommended unless you
+        encounter GPU out-of-memory errors, since this can result in slower
+        sorting.
     save_preprocessed_copy : bool; default=False.
         If True, save a pre-processed copy of the data (including drift
         correction) to `temp_wh.dat` in the results directory and format Phy
@@ -150,6 +156,8 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
     try:
         logger.info(f"Kilosort version {kilosort.__version__}")
         logger.info(f"Sorting {filename}")
+        if clear_cache:
+            logger.info('clear_cache=True')
         logger.info('-'*40)
 
         if data_dtype is None:
@@ -189,7 +197,6 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
             print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
             logger.debug(f"Initial ops:\n{print_ops}\n")
 
-
         # Set preprocessing and drift correction parameters
         ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
         np.random.seed(1)
@@ -197,7 +204,7 @@
         torch.random.manual_seed(1)
         ops, bfile, st0 = compute_drift_correction(
             ops, device, tic0=tic0, progress_bar=progress_bar,
-            file_object=file_object
+            file_object=file_object, clear_cache=clear_cache,
             )
 
         # Check scale of data for log file
@@ -208,14 +215,20 @@
             io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
 
         # Sort spikes and save results
-        st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
-                                    progress_bar=progress_bar)
-        clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
-                                   progress_bar=progress_bar)
+        st, tF, _, _ = detect_spikes(
+            ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
+            clear_cache=clear_cache
+        )
+        clu, Wall = cluster_spikes(
+            st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
+            clear_cache=clear_cache
+        )
         ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
-            save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
-                         save_extra_vars=save_extra_vars,
-                         save_preprocessed_copy=save_preprocessed_copy)
+            save_sorting(
+                ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
+                save_extra_vars=save_extra_vars,
+                save_preprocessed_copy=save_preprocessed_copy
+            )
     except:
         # This makes sure the full traceback is written to log file.
         logger.exception('Encountered error in `run_kilosort`:')
@@ -456,7 +469,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
 
 
 def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
-                             file_object=None):
+                             file_object=None, clear_cache=False):
     """Compute drift correction parameters and save them to `ops`.
 
     Parameters
@@ -504,7 +517,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
         file_object=file_object
         )
 
-    ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar)
+    ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar,
+                            clear_cache=clear_cache)
     bfile.close()
     logger.info(f'drift computed in {time.time()-tic : .2f}s; ' +
                 f'total {time.time()-tic0 : .2f}s')
@@ -526,7 +540,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
     return ops, bfile, st
 
 
-def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
+def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
+                  clear_cache=False):
     """Detect spikes via template deconvolution.
 
     Parameters
@@ -563,7 +578,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
     logger.info(' ')
     logger.info(f'Extracting spikes using templates')
     logger.info('-'*40)
-    st0, tF, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
+    st0, tF, ops = spikedetect.run(
+        ops, bfile, device=device, progress_bar=progress_bar,
+        clear_cache=clear_cache
+    )
     tF = torch.from_numpy(tF)
     logger.info(f'{len(st0)} spikes extracted in {time.time()-tic : .2f}s; ' +
                 f'total {time.time()-tic0 : .2f}s')
@@ -576,8 +594,10 @@
     logger.info(' ')
     logger.info('First clustering')
     logger.info('-'*40)
-    clu, Wall = clustering_qr.run(ops, st0, tF, mode='spikes', device=device,
-                                  progress_bar=progress_bar)
+    clu, Wall = clustering_qr.run(
+        ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar,
+        clear_cache=clear_cache
+    )
     Wall3 = template_matching.postprocess_templates(Wall, ops, clu, st0, device=device)
     logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
                 f'total {time.time()-tic0 : .2f}s')
@@ -600,7 +620,8 @@
     return st, tF, Wall, clu
 
 
-def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
+def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
+                   clear_cache=False):
     """Cluster spikes using graph-based methods.
 
     Parameters
@@ -636,8 +657,10 @@
     logger.info(' ')
     logger.info('Final clustering')
     logger.info('-'*40)
-    clu, Wall = clustering_qr.run(ops, st, tF, mode = 'template', device=device,
-                                  progress_bar=progress_bar)
+    clu, Wall = clustering_qr.run(
+        ops, st, tF, mode = 'template', device=device, progress_bar=progress_bar,
+        clear_cache=clear_cache
+    )
     logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
                 f'total {time.time()-tic0 : .2f}s')
     logger.debug(f'clu shape: {clu.shape}')
diff --git a/kilosort/spikedetect.py b/kilosort/spikedetect.py
index 6fa94053..2b6e1416 100644
--- a/kilosort/spikedetect.py
+++ b/kilosort/spikedetect.py
@@ -194,7 +194,8 @@ def yweighted(yc, iC, adist, xy, device=torch.device('cuda')):
     yct = (cF0 * yy[:,xy[:,0]]).sum(0)
     return yct
 
-def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
+def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
+        clear_cache=False):
     sig = ops['settings']['min_template_size']
     nsizes = ops['settings']['template_sizes']
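
Usage note (not part of the patch): with these changes applied, the flag is reachable from the top-level API. Below is a minimal sketch of opting in, assuming a Kilosort 4 install; the channel count, probe map, and recording path are hypothetical placeholders, not values taken from this diff.

    from kilosort import run_kilosort

    # Hypothetical example values; substitute your own probe map,
    # channel count, and recording path.
    settings = {'n_chan_bin': 385}

    # clear_cache=True threads down through drift correction, spike
    # detection, and clustering, triggering gc.collect() and
    # torch.cuda.empty_cache() between memory-heavy steps (see
    # clustering_qr.run above). Leave it False unless sorting fails
    # with CUDA out-of-memory errors.
    results = run_kilosort(
        settings=settings,
        probe_name='neuropixPhase3B1_kilosortChanMap.mat',
        filename='/path/to/recording.bin',
        clear_cache=True,
    )

The default stays False for a reason: `torch.cuda.empty_cache()` returns cached blocks to the driver, so subsequent allocations must go back through the slower CUDA allocator. That re-allocation cost is the slowdown the new docstring warns about.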