Skip to content

Commit

Permalink
[FIX] save models in output dir (#219)
Browse files Browse the repository at this point in the history
* save models in output dir

* improve command line help

* change model_name to model

* ignore link
  • Loading branch information
Remi-Gau authored Aug 5, 2024
1 parent 8cf6133 commit 11943d1
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 42 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ jobs:
--participant_label 302 307 \
--space MNI152NLin2009cAsym \
--reset_database \
--model 1to5 \
-vv
- run:
name: rerun prepare - fast as output already exists
Expand Down
17 changes: 9 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,19 @@ clean-models: ## remove pretrained models
rm -fr models/

models:
bidsmreye_model --model_name 1to6
bidsmreye_model --model 1to6
models/dataset1_guided_fixations.h5:
bidsmreye_model
models/dataset2_pursuit.h5:
bidsmreye_model --model_name 2_pursuit
bidsmreye_model --model 2_pursuit
models/dataset3_openclosed.h5:
bidsmreye_model --model_name 3_openclosed
bidsmreye_model --model 3_openclosed
models/dataset3_pursuit.h5:
bidsmreye_model --model_name 3_pursuit
bidsmreye_model --model 3_pursuit
models/dataset4_pursuit.h5:
bidsmreye_model --model_name 4_pursuit
bidsmreye_model --model 4_pursuit
models/dataset5_free_viewing.h5:
bidsmreye_model --model_name 5_free_viewing
bidsmreye_model --model 5_free_viewing

## DOC
.PHONY: docs docs/source/FAQ.md
Expand Down Expand Up @@ -137,6 +137,7 @@ generalize: ## demo: predicts labels of MOAE dataset
$$PWD/outputs/moae_fmriprep/derivatives \
participant \
generalize \
--model 1_guided_fixations \
-vv


Expand Down Expand Up @@ -234,7 +235,7 @@ docker_build:
docker_build_no_cache:
docker build --tag cpplab/bidsmreye:unstable --no-cache --file Dockerfile .

docker_demo: docker_build clean-demo
docker_demo: clean-demo
make docker_prepare_data
make docker_generalize

Expand All @@ -257,7 +258,7 @@ docker_generalize:
/home/neuro/data/ \
/home/neuro/outputs/ \
participant \
generalize
generalize --model 1_guided_fixations

docker_ds002799: get_ds002799
# datalad unlock $$PWD/tests/data/ds002799/derivatives/fmriprep/sub-30[27]/ses-*/func/*run-*preproc*bold*
Expand Down
2 changes: 1 addition & 1 deletion bidsmreye/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ def cli_download(argv: Any = sys.argv) -> None:
parser = download_parser(formatter_class=RichHelpFormatter)
args = parser.parse_args(argv[1:])

download(model_name=args.model_name, output_dir=args.output_dir)
download(model=args.model, output_dir=args.output_dir)
10 changes: 4 additions & 6 deletions bidsmreye/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def common_parser(formatter_class: type[HelpFormatter] = HelpFormatter) -> Argum
# TODO make it possible to pass path to a model ?
generalize_parser.add_argument(
"--model",
help="model to use",
help=f"Model to use. Default model: {default_model()}.",
choices=available_models(),
default=default_model(),
)
Expand All @@ -200,7 +200,7 @@ def common_parser(formatter_class: type[HelpFormatter] = HelpFormatter) -> Argum
# TODO make it possible to pass path to a model ?
all_parser.add_argument(
"--model",
help="model to use",
help=f"Model to use. Default model: {default_model()}.",
choices=available_models(),
default=default_model(),
)
Expand Down Expand Up @@ -228,10 +228,8 @@ def download_parser(
formatter_class=formatter_class,
)
parser.add_argument(
"--model_name",
help="""
Model to download.
""",
"--model",
help=f"Model to download. Default model: {default_model()}.",
choices=available_models(),
default=default_model(),
)
Expand Down
7 changes: 6 additions & 1 deletion bidsmreye/bidsmreye.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,12 @@ def bidsmreye(
if action in {"all", "generalize"} and isinstance(cfg.model_weights_file, str):
from bidsmreye.download import download

cfg.model_weights_file = download(cfg.model_weights_file)
model_output_dir = cfg.output_dir / "models"
model_output_dir.mkdir(exist_ok=True, parents=True)

cfg.model_weights_file = download(
cfg.model_weights_file, output_dir=model_output_dir
)

dispatch(analysis_level=analysis_level, action=action, cfg=cfg)

Expand Down
24 changes: 12 additions & 12 deletions bidsmreye/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@


def download(
model_name: str | Path | None = None, output_dir: Path | str | None = None
model: str | Path | None = None, output_dir: Path | str | None = None
) -> Path | None:
"""Download the models from OSF.
:param model_name: Model to download. defaults to None
:type model_name: str, optional
:param model: Model to download. defaults to None
:type model: str, optional
:param output_dir: Path where to save the model. Defaults to None.
:type output_dir: Path, optional
:return: Path to the downloaded model.
:rtype: Path
"""
if not model_name:
model_name = default_model()
if isinstance(model_name, Path):
assert model_name.is_file()
return model_name.absolute()
if model_name not in available_models():
warnings.warn(f"{model_name} is not a valid model name.", stacklevel=3)
if not model:
model = default_model()
if isinstance(model, Path):
assert model.is_file()
return model.absolute()
if model not in available_models():
warnings.warn(f"{model} is not a valid model name.", stacklevel=3)
return None

if output_dir is None:
Expand All @@ -52,10 +52,10 @@ def download(
with resources.as_file(source) as registry_file:
POOCH.load_registry(registry_file)

output_file = output_dir / f"dataset_{model_name}"
output_file = output_dir / f"dataset_{model}"

if not output_file.is_file():
file_idx = available_models().index(model_name)
file_idx = available_models().index(model)
filename = f"dataset_{available_models()[file_idx]}.h5"
output_file = POOCH.fetch(filename, progressbar=True)
if isinstance(output_file, str):
Expand Down
18 changes: 9 additions & 9 deletions bidsmreye/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@

def methods(
output_dir: str | Path | None = None,
model_name: str | None = None,
model: str | None = None,
qc_only: bool = False,
) -> Path:
"""Write method section.
:param output_dir: Defaults to Path(".")
:type output_dir: Union[str, Path], optional
:param model_name: Defaults to None.
:type model_name: str, optional
:param model: Defaults to None.
:type model: str, optional
:return: Output file name.
:rtype: Path
Expand All @@ -41,26 +41,26 @@ def methods(
bib_file = str(Path(__file__).parent / "templates" / "CITATION.bib")
shutil.copy(bib_file, output_dir)

if not model_name:
model_name = default_model()
if not model:
model = default_model()

is_known_models = False
is_default_model = False
if model_name in available_models():
if model in available_models():
is_known_models = True
if model_name == default_model():
if model == default_model():
is_default_model = True

if not is_known_models:
warnings.warn(f"{model_name} is not a known model name.", stacklevel=3)
warnings.warn(f"{model} is not a known model name.", stacklevel=3)

template_file = str(Path(__file__).parent / "templates" / "CITATION.mustache")
with open(template_file) as template:
output = chevron.render(
template=template,
data={
"version": __version__,
"model_name": model_name,
"model": model,
"is_default_model": is_default_model,
"is_known_models": is_known_models,
"qc_only": qc_only,
Expand Down
2 changes: 1 addition & 1 deletion bidsmreye/templates/CITATION.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ across voxels (spatial normalization).
Voxels time series were used as inputs for generalization decoding
using a
{{#is_known_models}}
pre-trained model {{ model_name }} from deepMReye from [OSF](https://osf.io/23t5v).
pre-trained model {{ model }} from deepMReye from [OSF](https://osf.io/23t5v).
{{#is_default_model}}
This model was trained on the following datasets:
guided fixations (@alexander_open_2017),
Expand Down
3 changes: 3 additions & 0 deletions mlc_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
"ignorePatterns": [
{
"pattern": "^https://doi.org\/"
},
{
"pattern": "^./faq.md#.*"
}
],
"timeout": "20s",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_download(tmp_path):
download(model_name="1_guided_fixations", output_dir=str(tmp_path))
download(model="1_guided_fixations", output_dir=str(tmp_path))

assert tmp_path.is_dir()
assert (tmp_path / "dataset_1_guided_fixations.h5").is_file()
Expand All @@ -29,4 +29,4 @@ def test_download_basic():

def test_download_unknown_model():
with pytest.warns(UserWarning):
download(model_name="foo")
download(model="foo")
2 changes: 1 addition & 1 deletion tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_methods(tmp_path):


def test_methods_calibration_data(tmp_path):
output_file = methods(output_dir=tmp_path, model_name="calibration_data")
output_file = methods(output_dir=tmp_path, model="calibration_data")
assert output_file.is_file()


Expand Down
2 changes: 1 addition & 1 deletion tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_download_parser():

args, _ = parser.parse_known_args(
[
"--model_name",
"--model",
"1_guided_fixations",
"--output_dir",
"/home/bob/models",
Expand Down

0 comments on commit 11943d1

Please sign in to comment.