From 11943d1b3b2ca6817b9cfd3aeab80f0b003adc98 Mon Sep 17 00:00:00 2001 From: Remi Gau Date: Mon, 5 Aug 2024 18:30:49 +0200 Subject: [PATCH] [FIX] save models in output dir (#219) * save models in output dir * improve command line help * change model_name to model * ignore link --- .circleci/config.yml | 1 + Makefile | 17 +++++++++-------- bidsmreye/_cli.py | 2 +- bidsmreye/_parsers.py | 10 ++++------ bidsmreye/bidsmreye.py | 7 ++++++- bidsmreye/download.py | 24 ++++++++++++------------ bidsmreye/methods.py | 18 +++++++++--------- bidsmreye/templates/CITATION.mustache | 2 +- mlc_config.json | 3 +++ tests/test_download.py | 4 ++-- tests/test_methods.py | 2 +- tests/test_parsers.py | 2 +- 12 files changed, 50 insertions(+), 42 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 862e1f7..15c86a7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -99,6 +99,7 @@ jobs: --participant_label 302 307 \ --space MNI152NLin2009cAsym \ --reset_database \ + --model 1to5 \ -vv - run: name: rerun prepare - fast as output already exists diff --git a/Makefile b/Makefile index aaf4387..b355071 100644 --- a/Makefile +++ b/Makefile @@ -70,19 +70,19 @@ clean-models: ## remove pretrained models rm -fr models/ models: - bidsmreye_model --model_name 1to6 + bidsmreye_model --model 1to6 models/dataset1_guided_fixations.h5: bidsmreye_model models/dataset2_pursuit.h5: - bidsmreye_model --model_name 2_pursuit + bidsmreye_model --model 2_pursuit models/dataset3_openclosed.h5: - bidsmreye_model --model_name 3_openclosed + bidsmreye_model --model 3_openclosed models/dataset3_pursuit.h5: - bidsmreye_model --model_name 3_pursuit + bidsmreye_model --model 3_pursuit models/dataset4_pursuit.h5: - bidsmreye_model --model_name 4_pursuit + bidsmreye_model --model 4_pursuit models/dataset5_free_viewing.h5: - bidsmreye_model --model_name 5_free_viewing + bidsmreye_model --model 5_free_viewing ## DOC .PHONY: docs docs/source/FAQ.md @@ -137,6 +137,7 @@ generalize: ## demo: predicts labels of MOAE dataset $$PWD/outputs/moae_fmriprep/derivatives \ participant \ generalize \ + --model 1_guided_fixations \ -vv @@ -234,7 +235,7 @@ docker_build: docker_build_no_cache: docker build --tag cpplab/bidsmreye:unstable --no-cache --file Dockerfile . -docker_demo: docker_build clean-demo +docker_demo: clean-demo make docker_prepare_data make docker_generalize @@ -257,7 +258,7 @@ docker_generalize: /home/neuro/data/ \ /home/neuro/outputs/ \ participant \ - generalize + generalize --model 1_guided_fixations docker_ds002799: get_ds002799 # datalad unlock $$PWD/tests/data/ds002799/derivatives/fmriprep/sub-30[27]/ses-*/func/*run-*preproc*bold* diff --git a/bidsmreye/_cli.py b/bidsmreye/_cli.py index 54e605f..919a739 100644 --- a/bidsmreye/_cli.py +++ b/bidsmreye/_cli.py @@ -69,4 +69,4 @@ def cli_download(argv: Any = sys.argv) -> None: parser = download_parser(formatter_class=RichHelpFormatter) args = parser.parse_args(argv[1:]) - download(model_name=args.model_name, output_dir=args.output_dir) + download(model=args.model, output_dir=args.output_dir) diff --git a/bidsmreye/_parsers.py b/bidsmreye/_parsers.py index 05b1e61..69dcbdf 100644 --- a/bidsmreye/_parsers.py +++ b/bidsmreye/_parsers.py @@ -177,7 +177,7 @@ def common_parser(formatter_class: type[HelpFormatter] = HelpFormatter) -> Argum # TODO make it possible to pass path to a model ? generalize_parser.add_argument( "--model", - help="model to use", + help=f"Model to use. Default model: {default_model()}.", choices=available_models(), default=default_model(), ) @@ -200,7 +200,7 @@ def common_parser(formatter_class: type[HelpFormatter] = HelpFormatter) -> Argum # TODO make it possible to pass path to a model ? all_parser.add_argument( "--model", - help="model to use", + help=f"Model to use. Default model: {default_model()}.", choices=available_models(), default=default_model(), ) @@ -228,10 +228,8 @@ def download_parser( formatter_class=formatter_class, ) parser.add_argument( - "--model_name", - help=""" -Model to download. - """, + "--model", + help=f"Model to download. Default model: {default_model()}.", choices=available_models(), default=default_model(), ) diff --git a/bidsmreye/bidsmreye.py b/bidsmreye/bidsmreye.py index c82fbc6..94bf70d 100755 --- a/bidsmreye/bidsmreye.py +++ b/bidsmreye/bidsmreye.py @@ -75,7 +75,12 @@ def bidsmreye( if action in {"all", "generalize"} and isinstance(cfg.model_weights_file, str): from bidsmreye.download import download - cfg.model_weights_file = download(cfg.model_weights_file) + model_output_dir = cfg.output_dir / "models" + model_output_dir.mkdir(exist_ok=True, parents=True) + + cfg.model_weights_file = download( + cfg.model_weights_file, output_dir=model_output_dir + ) dispatch(analysis_level=analysis_level, action=action, cfg=cfg) diff --git a/bidsmreye/download.py b/bidsmreye/download.py index 1bf56b6..9f9dd06 100644 --- a/bidsmreye/download.py +++ b/bidsmreye/download.py @@ -16,12 +16,12 @@ def download( - model_name: str | Path | None = None, output_dir: Path | str | None = None + model: str | Path | None = None, output_dir: Path | str | None = None ) -> Path | None: """Download the models from OSF. - :param model_name: Model to download. defaults to None - :type model_name: str, optional + :param model: Model to download. defaults to None + :type model: str, optional :param output_dir: Path where to save the model. Defaults to None. :type output_dir: Path, optional @@ -29,13 +29,13 @@ def download( :return: Path to the downloaded model. :rtype: Path """ - if not model_name: - model_name = default_model() - if isinstance(model_name, Path): - assert model_name.is_file() - return model_name.absolute() - if model_name not in available_models(): - warnings.warn(f"{model_name} is not a valid model name.", stacklevel=3) + if not model: + model = default_model() + if isinstance(model, Path): + assert model.is_file() + return model.absolute() + if model not in available_models(): + warnings.warn(f"{model} is not a valid model name.", stacklevel=3) return None if output_dir is None: @@ -52,10 +52,10 @@ def download( with resources.as_file(source) as registry_file: POOCH.load_registry(registry_file) - output_file = output_dir / f"dataset_{model_name}" + output_file = output_dir / f"dataset_{model}" if not output_file.is_file(): - file_idx = available_models().index(model_name) + file_idx = available_models().index(model) filename = f"dataset_{available_models()[file_idx]}.h5" output_file = POOCH.fetch(filename, progressbar=True) if isinstance(output_file, str): diff --git a/bidsmreye/methods.py b/bidsmreye/methods.py index ba6b47b..49973b3 100644 --- a/bidsmreye/methods.py +++ b/bidsmreye/methods.py @@ -15,7 +15,7 @@ def methods( output_dir: str | Path | None = None, - model_name: str | None = None, + model: str | None = None, qc_only: bool = False, ) -> Path: """Write method section. @@ -23,8 +23,8 @@ def methods( :param output_dir: Defaults to Path(".") :type output_dir: Union[str, Path], optional - :param model_name: Defaults to None. - :type model_name: str, optional + :param model: Defaults to None. + :type model: str, optional :return: Output file name. :rtype: Path @@ -41,18 +41,18 @@ def methods( bib_file = str(Path(__file__).parent / "templates" / "CITATION.bib") shutil.copy(bib_file, output_dir) - if not model_name: - model_name = default_model() + if not model: + model = default_model() is_known_models = False is_default_model = False - if model_name in available_models(): + if model in available_models(): is_known_models = True - if model_name == default_model(): + if model == default_model(): is_default_model = True if not is_known_models: - warnings.warn(f"{model_name} is not a known model name.", stacklevel=3) + warnings.warn(f"{model} is not a known model name.", stacklevel=3) template_file = str(Path(__file__).parent / "templates" / "CITATION.mustache") with open(template_file) as template: @@ -60,7 +60,7 @@ def methods( template=template, data={ "version": __version__, - "model_name": model_name, + "model": model, "is_default_model": is_default_model, "is_known_models": is_known_models, "qc_only": qc_only, diff --git a/bidsmreye/templates/CITATION.mustache b/bidsmreye/templates/CITATION.mustache index 63f1e22..78f46bd 100644 --- a/bidsmreye/templates/CITATION.mustache +++ b/bidsmreye/templates/CITATION.mustache @@ -29,7 +29,7 @@ across voxels (spatial normalization). Voxels time series were used as inputs for generalization decoding using a {{#is_known_models}} -pre-trained model {{ model_name }} from deepMReye from [OSF](https://osf.io/23t5v). +pre-trained model {{ model }} from deepMReye from [OSF](https://osf.io/23t5v). {{#is_default_model}} This model was trained on the following datasets: guided fixations (@alexander_open_2017), diff --git a/mlc_config.json b/mlc_config.json index e07809a..1d29a26 100644 --- a/mlc_config.json +++ b/mlc_config.json @@ -2,6 +2,9 @@ "ignorePatterns": [ { "pattern": "^https://doi.org\/" + }, + { + "pattern": "^./faq.md#.*" } ], "timeout": "20s", diff --git a/tests/test_download.py b/tests/test_download.py index aba4076..40641b9 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -9,7 +9,7 @@ def test_download(tmp_path): - download(model_name="1_guided_fixations", output_dir=str(tmp_path)) + download(model="1_guided_fixations", output_dir=str(tmp_path)) assert tmp_path.is_dir() assert (tmp_path / "dataset_1_guided_fixations.h5").is_file() @@ -29,4 +29,4 @@ def test_download_basic(): def test_download_unknown_model(): with pytest.warns(UserWarning): - download(model_name="foo") + download(model="foo") diff --git a/tests/test_methods.py b/tests/test_methods.py index 2bd91ad..e63fc02 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -10,7 +10,7 @@ def test_methods(tmp_path): def test_methods_calibration_data(tmp_path): - output_file = methods(output_dir=tmp_path, model_name="calibration_data") + output_file = methods(output_dir=tmp_path, model="calibration_data") assert output_file.is_file() diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 5677941..d41260e 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -53,7 +53,7 @@ def test_download_parser(): args, _ = parser.parse_known_args( [ - "--model_name", + "--model", "1_guided_fixations", "--output_dir", "/home/bob/models",