Skip to content

Commit

Permalink
Merge pull request #338 from RichieHakim/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
RichieHakim authored Oct 23, 2024
2 parents 1953f69 + d3603e4 commit c77b628
Show file tree
Hide file tree
Showing 15 changed files with 312 additions and 107 deletions.
18 changes: 5 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ computer.
notebook](https://github.com/RichieHakim/ROICaT/blob/main/notebooks/jupyter/tracking/1_tracking_interactive_notebook.ipynb)
- [Google
CoLab](https://githubtocolab.com/RichieHakim/ROICaT/blob/main/notebooks/colab/tracking/1_tracking_interactive_notebook.ipynb)
- [Command line interface script](https://github.com/RichieHakim/ROICaT/blob/main/scripts/demo_run_tracking_pipeline.py)
- [Command line interface script](https://github.com/RichieHakim/ROICaT/blob/main/scripts/run_tracking.sh):
```shell
roicat --pipeline tracking --path_params /path/to/params.yaml --dir_data /folder/with/data/ --dir_save /folder/save/ --prefix_name_save expName --verbose
```

### CLASSIFICATION:
- [Interactive notebook -
Expand Down Expand Up @@ -111,21 +114,10 @@ installation process, please make a [github
issue](https://github.com/RichieHakim/ROICaT/issues) with the error.

### 0. Requirements
- Your segmented data. For example Suite2p output data (stat.npy and ops.npy
files), CaImAn output data (results.h5 files), or any other type of data using
this [custom data importing
notebook](https://github.com/RichieHakim/ROICaT/blob/main/notebooks/jupyter/other/demo_data_importing.ipynb).
- [Anaconda](https://www.anaconda.com/distribution/) or
[Miniconda](https://docs.conda.io/en/latest/miniconda.html).
- If using Windows: [Microsoft C++ Build
Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/)
- If using linux/unix(MacOS): GCC >= 5.4.0, ideally == 9.2.0. Check with: `gcc
--version`, or just google how to do this on your operating system.
- **Optional:** [CUDA compatible NVIDIA
GPU](https://developer.nvidia.com/cuda-gpus) and
[drivers](https://developer.nvidia.com/cuda-toolkit-archive). Using a GPU can
increase ROICaT speeds ~5-50x, though without it, ROICaT will still run
reasonably quickly. GPU support is not available for Macs.
- The below commands should be run in the terminal (Mac/Linux) or Anaconda
Prompt (Windows).

Expand All @@ -149,7 +141,7 @@ install --user 'roicat[all]'` For installing GPU support on Windows, see
**Note on opencv:** The headless version of opencv is installed by default. If
the regular version is already installed, you will need to uninstall it first.

### 3. Clone the repo to get the scripts and notebooks
### 3. Clone the repo to get the notebooks
```
git clone https://github.com/RichieHakim/ROICaT
```
Expand Down

Large diffs are not rendered by default.

14 changes: 1 addition & 13 deletions notebooks/jupyter/other/demo_data_importing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "77c4421b",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -903,18 +903,6 @@
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
},
"varInspector": {
"cols": {
"lenName": 16,
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ jupyter_bokeh==4.0.5
onnx2torch==1.5.15
scikit-image==0.24.0
richfile>=0.4.5
romatch-roicat>=0.1.0
romatch-roicat>=0.1.0
kornia==0.7.3
4 changes: 3 additions & 1 deletion roicat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@
for pkg in __all__:
exec('from . import ' + pkg)

__version__ = '1.4.0'
from .__main__ import run_pipeline

__version__ = '1.4.1'
125 changes: 125 additions & 0 deletions roicat/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Optional

import argparse
import datetime
from pathlib import Path

from roicat import pipelines, util, helpers

# Registry of runnable pipelines: maps the name accepted by ``--pipeline`` (and
# by ``run_pipeline(pipeline_name=...)``) to the callable that executes it.
# Each callable is invoked as ``PIPELINES[name](params=params)``.
PIPELINES = {
'tracking': pipelines.pipeline_tracking,
}

def add_args(parser: argparse.ArgumentParser):
    """
    Adds the ROICaT command line arguments to the parser.

    Args:
        parser (argparse.ArgumentParser):
            Parser to register the arguments on. It is modified in place.

    Returns:
        (argparse.ArgumentParser):
            parser (argparse.ArgumentParser):
                The same parser object that was passed in.
    """
    # Running pipelines
    ## pipeline type. ``choices`` makes argparse reject an unknown pipeline name
    ##  with a usage error instead of letting it fail later with a raw KeyError
    ##  on the PIPELINES lookup.
    parser.add_argument(
        '--pipeline',
        type=str,
        required=False,
        choices=list(PIPELINES.keys()),
        help=f"Type of pipeline to run. Options: {', '.join(PIPELINES.keys())}",
    )

    ## path_params
    parser.add_argument('--path_params', type=str, required=False, help='Path to the parameters .yaml file to use')

    ## dir_data (written into params['data_loading']['dir_outer'] by run_pipeline)
    parser.add_argument('--dir_data', type=str, required=False, default=None, help='Path to the outer directory containing the dataset')

    ## dir_save
    parser.add_argument('--dir_save', type=str, required=False, default=None, help='Path to the directory to save the results')

    ## prefix_name_save. Default is the current date and time
    parser.add_argument('--prefix_name_save', type=str, required=False, default=None, help='Prefix to append to the saved files. Default is the current date and time')

    ## verbose
    parser.add_argument('--verbose', action='store_true', help='Print verbose output')

    # Other arguments
    ## version
    parser.add_argument('--version', action='store_true', help='Print the version number.')

    return parser


def run_pipeline(
    pipeline_name: str = 'tracking',
    path_params: Optional[str] = None,
    dir_data: Optional[str] = None,
    dir_save: Optional[str] = None,
    prefix_name_save: Optional[str] = None,
    verbose: bool = False,
):
    """
    Call a pipeline with the specified parameters.

    Args:
        pipeline_name (str):
            Key into ``PIPELINES`` selecting which pipeline to run.
        path_params (Optional[str]):
            Path to a parameters .yaml file. If ``None``, a warning is printed
            and the pipeline's default parameters are used.
        dir_data (Optional[str]):
            If not ``None``, overrides ``params['data_loading']['dir_outer']``.
        dir_save (Optional[str]):
            If not ``None``, overrides ``params['results_saving']['dir_save']``.
        prefix_name_save (Optional[str]):
            If not ``None``, overrides
            ``params['results_saving']['prefix_name_save']``.
        verbose (bool):
            Written to ``params['general']['verbose']``.

    Returns:
        (tuple): tuple containing:
            results:
                First value returned by the pipeline callable.
            run_data:
                Second value returned by the pipeline callable.
            params:
                Third value returned by the pipeline callable.
    """
    # Load in parameters to use
    if path_params is not None:
        params = helpers.yaml_load(path_params)
    else:
        print(f"WARNING: No parameters file specified. Using default parameters for pipeline '{pipeline_name}'")
        params = {}

    # These lines are for safety, to make sure that all params are present and valid
    params_defaults = util.get_default_parameters(pipeline=pipeline_name)
    params = helpers.prepare_params(params=params, defaults=params_defaults)

    # User specified overrides: (section, key, value). Only non-None values are
    #  written, so anything left unspecified keeps the value from the params.
    # NOTE(review): ``verbose`` is a bool and therefore never None, so it always
    #  replaces params['general']['verbose'] -- including with False when the
    #  caller did not ask for it. Kept as-is to preserve existing behavior.
    overrides = (
        ('data_loading', 'dir_outer', dir_data),
        ('results_saving', 'dir_save', dir_save),
        ('results_saving', 'prefix_name_save', prefix_name_save),
        ('general', 'verbose', verbose),
    )
    for section, key, value in overrides:
        if value is not None:
            params[section][key] = value

    # Run pipeline and hand the outputs back to the caller (they were previously
    #  discarded, which made the function unusable programmatically).
    results, run_data, params = PIPELINES[pipeline_name](params=params)
    return results, run_data, params


def _prepare_path(path, must_exist=False):
"""
Resolve path and check if it exists.
"""
if path is not None:
path = str(Path(path).resolve())
if not Path(path).exists() and must_exist:
raise FileNotFoundError(f'Path not found: {path}')
return path


def main():
    """
    Command line entry point: parse the arguments and run the requested pipeline.

    Prints the installed version when ``--version`` is given, prints a welcome
    message when no argument is given at all, and calls ``run_pipeline`` when
    ``--pipeline`` is given. The paths are resolved (and ``--path_params`` /
    ``--dir_data`` checked for existence) regardless of whether a pipeline was
    requested.
    """
    cli = add_args(argparse.ArgumentParser(description='Run a ROICaT pipeline'))
    args = cli.parse_args()

    if args.version:
        import importlib.metadata
        print(f"ROICaT version: {importlib.metadata.version('roicat')}")

    ## If nothing is specified, print a message
    nothing_specified = not any(vars(args).values())
    if nothing_specified:
        print("Welcome to ROICaT!\n Use the --help flag to see the available options.")

    ## Inputs must already exist; the save directory is allowed to be missing
    path_params = _prepare_path(args.path_params, must_exist=True)
    dir_data = _prepare_path(args.dir_data, must_exist=True)
    dir_save = _prepare_path(args.dir_save, must_exist=False)

    ## Fall back to a timestamp when no prefix was given
    prefix_name_save = args.prefix_name_save
    if prefix_name_save is None:
        prefix_name_save = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    ## Only run when the pipeline argument is specified
    if args.pipeline is None:
        return

    run_pipeline(
        pipeline_name=args.pipeline,
        path_params=path_params,
        dir_data=dir_data,
        dir_save=dir_save,
        prefix_name_save=prefix_name_save,
        verbose=args.verbose,
    )


if __name__ == '__main__':
    main()
132 changes: 130 additions & 2 deletions roicat/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Dict, Tuple, Union, Optional, Any, Callable, Iterable, Sequence, Type, Any, MutableMapping

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
import os
Expand All @@ -14,16 +16,18 @@
from PIL import ImageTk
import csv
import warnings
from typing import List, Dict, Tuple, Union, Optional, Any, Callable, Iterable, Sequence, Type, Any, MutableMapping
import time
import datetime

import numpy as np
import torch
import torchvision
import scipy.sparse
import sparse
from tqdm import tqdm
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import yaml
import optuna

"""
All of these are from basic_neural_processing_modules
Expand Down Expand Up @@ -672,6 +676,130 @@ def check(
self.num_trial += 1


class OptunaProgressBar:
    """
    A customizable progress bar for Optuna's study.optimize().
    Instances are callable with the ``(study, trial)`` signature, so they can
    be used as a callback. Each call refreshes the bar at most once per
    ``mininterval`` seconds.

    Args:
        n_trials (int, optional):
            The number of trials. Exactly one of n_trials or timeout must be
            set.
        timeout (float, optional):
            The maximum time to run in seconds. Exactly one of n_trials or
            timeout must be set.
        tqdm_kwargs (dict, optional):
            Additional keyword arguments to pass to tqdm. Defaults are built
            from the ``TQDM_*`` environment variables read in
            ``_get_env_variables`` plus ``dynamic_ncols=True`` and
            ``leave=False``; any key given here overrides those defaults.
    """

    def __init__(
        self,
        n_trials: Optional[int] = None,
        timeout: Optional[float] = None,
        **tqdm_kwargs: Any,
    ):
        if (n_trials is None) and (timeout is None):
            raise ValueError("Either n_trials or timeout must be set.")
        ## Exclusivity between n_trials and timeout
        elif (n_trials is not None) and (timeout is not None):
            raise ValueError("Only one of n_trials or timeout should be set.")

        ## Store user-provided values
        self._n_trials = n_trials
        self._timeout = timeout
        self._tqdm_kwargs = self._get_default_kwargs()

        ## Overwrite default values with user-provided values
        self._tqdm_kwargs.update(tqdm_kwargs)

        ## Initialize params for bar
        self.bar = None
        self._last_elapsed_seconds = 0.0

        ## Throttling state. Kept on the instance (not in module globals) so
        ##  that several progress bars do not interfere with each other.
        self._time_last_update = time.time()
        ## Timestamps of updates that are currently in progress
        self._times_calls = []

        ## Initialize progress bar
        if self._n_trials is not None:
            self.bar = tqdm(
                total=self._n_trials,
                **self._tqdm_kwargs,
            )
        else:
            total = tqdm.format_interval(self._timeout)
            fmt = "{desc} {percentage:3.0f}%|{bar}| {elapsed}/" + total
            ## The tqdm kwargs (including ``disable``) are honoured in timeout
            ##  mode too; a user-supplied ``bar_format`` takes precedence.
            self.bar = tqdm(
                total=self._timeout,
                **{'bar_format': fmt, **self._tqdm_kwargs},
            )

    def __call__(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial"):
        """
        Callback entry point: refreshes the bar unless it was refreshed less
        than ``mininterval`` seconds ago.

        Args:
            study (optuna.study.Study):
                Study whose trials and best value are displayed.
            trial (optuna.trial.FrozenTrial):
                Unused. Present to match the ``(study, trial)`` call signature.
        """
        current_time = time.time()
        min_interval = self._tqdm_kwargs.get('mininterval', 0.1)

        ## Skip if the last completed update is too recent
        if current_time - self._time_last_update < min_interval:
            return
        ## Skip if another update started too recently and is still running
        ##  (presumably guards concurrent callbacks, e.g. n_jobs > 1)
        if any(current_time - t < min_interval for t in self._times_calls):
            return

        self._times_calls.append(current_time)
        try:
            self._update(study)
            self._time_last_update = current_time
        finally:
            ## Always drop the in-progress marker, even if the update raised
            self._times_calls.remove(current_time)

    def _update(self, study: "optuna.study.Study"):
        """
        Moves the bar to the study's current progress and refreshes the
        description with the best trial found so far.
        """
        if self._n_trials is not None:
            ## Progress is the number of trials currently in the study
            i_trial = len(study.trials)
            self.bar.update(i_trial - self.bar.n)
        elif self._timeout is not None:
            ## Progress is the time elapsed since the first trial started
            t_start = study.trials[0].datetime_start
            elapsed = datetime.datetime.now() - t_start
            self.bar.update(elapsed.total_seconds() - self.bar.n)

        try:
            best_value = study.best_value
            best_trial = study.best_trial
            self.bar.set_description(f"Best trial: {best_trial.number}, value: {best_value:.6g}")
        except ValueError:
            pass  # No trials completed yet

    def close(self):
        """
        Closes the underlying tqdm bar. Safe to call more than once.
        """
        if self.bar is not None:
            self.bar.close()

    def _get_env_variables(self):
        """
        Reads the supported ``TQDM_*`` environment variables.

        Returns:
            (dict):
                Keys ``disable``, ``mininterval``, ``maxinterval``,
                ``miniters`` and ``smoothing``, with fallbacks used when a
                variable is unset. ``disable`` is True only when
                ``TQDM_DISABLE`` equals ``'true'`` (case-insensitive).
        """
        env = os.environ
        env_vars = {
            'disable': bool(env.get('TQDM_DISABLE', '').lower() == 'true'),
            'mininterval': float(env.get('TQDM_MININTERVAL', 0.1)),
            'maxinterval': float(env.get('TQDM_MAXINTERVAL', 10.0)),
            'miniters': int(env.get('TQDM_MINITERS', 1)),
            'smoothing': float(env.get('TQDM_SMOOTHING', 0.3)),
        }
        return env_vars

    def _get_default_kwargs(self):
        """
        Builds the default tqdm kwargs: the environment-derived values plus
        ``dynamic_ncols=True`` and ``leave=False``.
        """
        kwargs_default = self._get_env_variables()

        # Set default values for better handling in environments
        kwargs_default.setdefault('dynamic_ncols', True)
        kwargs_default.setdefault('leave', False)
        return kwargs_default


######################################################################################################################################
######################################################## FEATURIZATION ###############################################################
######################################################################################################################################
Expand Down
8 changes: 4 additions & 4 deletions roicat/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def pipeline_tracking(params: dict):
deterministic=params['general']['random_seed'] is not None,
)

if params['data_loading']['data_kind'] == 'data_suite2p':
assert params['data_loading']['dir_outer'] is not None, f"params['data_loading']['dir_outer'] must be specified if params['data_loading']['data_kind'] is 'data_suite2p'."
if params['data_loading']['data_kind'] == 'suite2p':
assert params['data_loading']['dir_outer'] is not None, f"params['data_loading']['dir_outer'] must be specified if params['data_loading']['data_kind'] is 'suite2p'."
paths_allStat = helpers.find_paths(
dir_outer=params['data_loading']['dir_outer'],
reMatch='stat.npy',
Expand All @@ -61,12 +61,12 @@ def pipeline_tracking(params: dict):
)[:]
paths_allOps = [str(Path(path).resolve().parent / 'ops.npy') for path in paths_allStat][:]

if len(paths_allStat) == 0:
raise FileNotFoundError(f"No stat.npy files found in '{params['data_loading']['dir_outer']}'")
print(f"Found the following stat.npy files:")
[print(f" {path}") for path in paths_allStat]
print(f"Found the following corresponding ops.npy files:")
[print(f" {path}") for path in paths_allOps]
if len(paths_allStat) == 0:
raise FileNotFoundError(f"No stat.npy files found in '{params['data_loading']['dir_outer']}'")


## Import data
Expand Down
Loading

0 comments on commit c77b628

Please sign in to comment.