diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index d146d5f5..95b3a183 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -523,7 +523,7 @@ def _setup_output( be resolved to the current working directory. output_root : str | None The base name for the output files. If `None` the output root name will - be resolved to casanovo_ + be resolved to casanovo_ overwrite: bool Whether to overwrite log file if it already exists in the output directory. @@ -544,10 +544,12 @@ def _setup_output( if output_dir is None: output_path = Path.cwd() else: - output_path = Path(output_dir) + output_path = Path(output_dir).expanduser().resolve() if not output_path.is_dir(): - raise FileNotFoundError( - f"Target output directory {output_dir} does not exists." + output_path.mkdir(parents=True) + warnings.warn( + f"Target output directory {output_dir} does not exists, " + "so it will be created." ) if not overwrite: diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index adc4c8cc..73d1da77 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -72,7 +72,7 @@ def __init__( if output_dir is None: self.callbacks = [] - warnings.warn( + logger.warning( "Checkpoint directory not set in ModelRunner, " "no checkpoint files will be saved." ) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index f221d502..9ad8490d 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -7,6 +7,7 @@ import os import pathlib import platform +import re import requests import shutil import tempfile @@ -944,3 +945,22 @@ def test_check_dir(tmp_path): utils.check_dir_file_exists(tmp_path, [dne_pattern]) utils.check_dir_file_exists(tmp_path, dne_pattern) + + +def test_setup_output(tmp_path, monkeypatch): + with monkeypatch.context() as mnk: + mnk.setattr(pathlib.Path, "cwd", lambda: tmp_path) + output_path, output_root = casanovo._setup_output( + None, None, False, "info" + ) + assert output_path.resolve() == tmp_path.resolve() + assert re.fullmatch(r"casanovo_\d+", output_root) is not None + + target_path = tmp_path / "foo" + with pytest.warns(UserWarning): + output_path, output_root = casanovo._setup_output( + str(target_path), "bar", False, "info" + ) + + assert output_path.resolve() == target_path.resolve() + assert output_root == "bar"