diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index f9c37f08..d146d5f5 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -12,7 +12,7 @@ import urllib.parse import warnings from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, List warnings.formatwarning = lambda message, category, *args, **kwargs: ( f"{category.__name__}: {message}" @@ -168,16 +168,13 @@ def sequence( to sequence peptides. If evaluate is set to True PEAK_PATH must be one or more annotated MGF file. """ - file_patterns = list() - if output_root is not None and not force_overwrite: - file_patterns = [f"{output_root}.log", f"{output_root}.mztab"] - - output, output_dir = _resolve_output( - output_dir, output_root, file_patterns, verbosity + output_path, output_root = _setup_output( + output_dir, output_root, force_overwrite, verbosity ) - config, model = setup_model(model, config, output, False) + utils.check_dir_file_exists(output_path, f"{output_root}.mztab") + config, model = setup_model(model, config, output_dir, output_root, False) start_time = time.time() - with ModelRunner(config, model, output_root, output_dir, False) as runner: + with ModelRunner(config, model, output_path, output_root, False) as runner: logger.info( "Sequencing %speptides from:", "and evaluating " if evaluate else "", @@ -185,7 +182,11 @@ def sequence( for peak_file in peak_path: logger.info(" %s", peak_file) - runner.predict(peak_path, output, evaluate=evaluate) + runner.predict( + peak_path, + str((output_path / output_root).with_suffix(".mztab")), + evaluate=evaluate, + ) psms = runner.writer.psms utils.log_sequencing_report( psms, start_time=start_time, end_time=time.time() @@ -225,17 +226,13 @@ def train( TRAIN_PEAK_PATH must be one or more annoated MGF files, such as those provided by MassIVE-KB, from which to train a new Casnovo model. 
""" - file_patterns = list() - if output_root is not None and not force_overwrite: - file_patterns = [f"{output_root}.log"] - - output, output_dir = _resolve_output( - output_dir, output_root, file_patterns, verbosity + output_path, output_root = _setup_output( + output_dir, output_root, force_overwrite, verbosity ) - config, model = setup_model(model, config, output, True) + config, model = setup_model(model, config, output_path, output_root, True) start_time = time.time() with ModelRunner( - config, model, output_root, output_dir, not force_overwrite + config, model, output_path, output_root, not force_overwrite ) as runner: logger.info("Training a model from:") for peak_file in train_peak_path: @@ -283,7 +280,7 @@ def configure(output: str) -> None: def setup_logging( - output: Optional[str], + log_file_path: Path, verbosity: str, ) -> Path: """Set up the logger. @@ -292,21 +289,11 @@ def setup_logging( Parameters ---------- - output : Optional[str] - The provided output file name. + log_file_path: Path + The log file path. verbosity : str The logging level to use in the console. - - Return - ------ - output : Path - The output file path. 
""" - if output is None: - output = f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" - - output = Path(output).expanduser().resolve() - logging_levels = { "debug": logging.DEBUG, "info": logging.INFO, @@ -333,9 +320,7 @@ def setup_logging( console_handler.setFormatter(console_formatter) root_logger.addHandler(console_handler) warnings_logger.addHandler(console_handler) - file_handler = logging.FileHandler( - output.with_suffix(".log"), encoding="utf8" - ) + file_handler = logging.FileHandler(log_file_path, encoding="utf8") file_handler.setFormatter(log_formatter) root_logger.addHandler(file_handler) warnings_logger.addHandler(file_handler) @@ -352,13 +337,12 @@ def setup_logging( logging.getLogger("torch").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) - return output - def setup_model( model: Optional[str], config: Optional[str], - output: Optional[Path], + output_dir: Optional[Path | str], + output_root_name: Optional[str], is_train: bool, ) -> Config: """Setup Casanovo for most commands. @@ -418,7 +402,8 @@ def setup_model( logger.info("Casanovo version %s", str(__version__)) logger.debug("model = %s", model) logger.debug("config = %s", config.file) - logger.debug("output = %s", output) + logger.debug("output directory = %s", output_dir) + logger.debug("output root name = %s", output_root_name) for key, value in config.items(): logger.debug("%s = %s", str(key), str(value)) @@ -522,42 +507,54 @@ def _get_model_weights(cache_dir: Path) -> str: ) -def _resolve_output( +def _setup_output( output_dir: str | None, output_root: str | None, - file_patterns: list[str], + overwrite: bool, verbosity: str, ) -> Tuple[Path, str]: """ - Resolves the output directory and sets up logging. + Set up the output directory, output file root name, and logging. Parameters: ----------- output_dir : str | None - The path to the output directory. If `None`, the current working - directory will be used. + The path to the output directory. 
If `None`, the output directory will + be resolved to the current working directory. output_root : str | None - The base name for the output files. If `None`, no specific base name is - set, and logging will be configured accordingly to the behavior of - `setup_logging`. - file_patterns : list[str] - A list of file patterns that should be checked within the `output_dir`. + The base name for the output files. If `None` the output root name will + be resolved to `casanovo_YYYYmmddHHMMSS` (current date and time). + overwrite: bool + Whether to overwrite log file if it already exists in the output + directory. verbosity : str The verbosity level for logging. Returns: -------- Tuple[Path, str] - The output directory and the base name for log and results files (if - applicable). + A tuple containing the resolved output directory and root name for + output files. """ - output_dir = Path(output_dir) if output_dir is not None else Path.cwd() - output_base_name = ( - None if output_root is None else (output_dir / output_root) - ) - utils.check_dir(output_dir, file_patterns) - output = setup_logging(output_base_name, verbosity) - return output, output_dir + if output_root is None: + output_root = ( + f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + ) + + if output_dir is None: + output_path = Path.cwd() + else: + output_path = Path(output_dir) + if not output_path.is_dir(): + raise FileNotFoundError( + f"Target output directory {output_dir} does not exists." 
+ ) + + if not overwrite: + utils.check_dir_file_exists(output_path, f"{output_root}.log") + + setup_logging((output_path / output_root).with_suffix(".log"), verbosity) + return output_path, output_root def _get_weights_from_url( diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 5c8833de..a7302565 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -40,17 +40,21 @@ class ModelRunner: model_filename : str, optional The model filename is required for eval and de novo modes, but not for training a model from scratch. - output_rootname : str, optional - The rootname for all output files (e.g. checkpoints or results) + output_dir : Path | None, optional + The directory where checkpoint files will be saved. If `None` no + checkpoint files will be saved and a warning will be triggered. + output_rootname : str | None, optional + The root name for checkpoint files (e.g., checkpoints or results). If + `None` no base name will be used for checkpoint files. """ def __init__( self, config: Config, model_filename: Optional[str] = None, - output_rootname: Optional[str] = None, - output_dir: Optional[str] = None, - overwrite_ckpt_check: bool = True, + output_dir: Optional[Path | None] = None, + output_rootname: Optional[str | None] = None, + overwrite_ckpt_check: Optional[bool] = True, ) -> None: """Initialize a ModelRunner""" self.config = config @@ -63,15 +67,19 @@ def __init__( self.loaders = None self.writer = None - output_dir = Path.cwd() if output_dir is None else output_dir - prefix = f"{output_rootname}." if output_rootname is not None else "" - curr_filename, best_filename = ( - prefix + "{epoch}-{step}", - prefix + "best", - ) + if output_dir is None: + self.callbacks = [] + warnings.warn( + "Checkpoint directory not set in ModelRunner, " + "no checkpoint files will be saved." + ) + return + prefix = f"{output_rootname}." 
if output_rootname is not None else "" + curr_filename = prefix + "{epoch}-{step}" + best_filename = prefix + "best" if overwrite_ckpt_check: - utils.check_dir( + utils.check_dir_file_exists( output_dir, [ f"{curr_filename.format(epoch='*', step='*')}.ckpt", @@ -167,7 +175,10 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision) def predict( - self, peak_path: Iterable[str], output: str, evaluate: bool = False + self, + peak_path: Iterable[str], + results_path: str, + evaluate: bool = False, ) -> None: """Predict peptide sequences with a trained Casanovo model. @@ -178,8 +189,8 @@ def predict( ---------- peak_path : iterable of str The path with the MS data files for predicting peptide sequences. - output : str - Where should the output be saved? + results_path : str + Sequencing results file path evaluate: bool whether to run model evaluation in addition to inference Note: peak_path most point to annotated MS data files when @@ -190,7 +201,7 @@ def predict( ------- self """ - self.writer = ms_io.MztabWriter(Path(output).with_suffix(".mztab")) + self.writer = ms_io.MztabWriter(results_path) self.writer.set_metadata( self.config, model=str(self.model_filename), diff --git a/casanovo/utils.py b/casanovo/utils.py index fb0c3327..1161b5eb 100644 --- a/casanovo/utils.py +++ b/casanovo/utils.py @@ -256,7 +256,9 @@ def log_sequencing_report( ) -def check_dir(dir: pathlib.Path, file_patterns: Iterable[str]) -> None: +def check_dir_file_exists( + dir: pathlib.Path, file_patterns: Iterable[str] | str +) -> None: """ Check that no file names in dir match any of file_patterns @@ -264,18 +266,20 @@ def check_dir(dir: pathlib.Path, file_patterns: Iterable[str]) -> None: ---------- dir : pathlib.Path The directory to check for matching file names - file_patterns : Iterable[str] - UNIX style wildcard pattern to test file names against + file_patterns : Iterable[str] | str + UNIX style wildcard 
pattern(s) to test file names against Raises ------ FileExistsError If matching file name is found in dir """ + if isinstance(file_patterns, str): + file_patterns = [file_patterns] + for pattern in file_patterns: - matches = list(dir.glob(pattern)) - if len(matches) > 0: + if next(dir.glob(pattern), None) is not None: raise FileExistsError( - f"File {matches[0].name} already exists in {dir} " - "and can not be overwritten." + f"File matching wildcard pattern {pattern} already exist in" + f"{dir} and can not be overwritten." ) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 8cdd291d..f221d502 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -169,14 +169,14 @@ def test_setup_model(monkeypatch): filename = pathlib.Path(tmp_dir) / "casanovo_massivekb_v3_0_0.ckpt" assert not filename.is_file() - _, result_path = casanovo.setup_model(None, None, None, False) + _, result_path = casanovo.setup_model(None, None, None, None, False) assert result_path.resolve() == filename.resolve() assert filename.is_file() assert mock_get.request_counter == 1 os.remove(result_path) assert not filename.is_file() - _, result = casanovo.setup_model(None, None, None, True) + _, result = casanovo.setup_model(None, None, None, None, True) assert result is None assert not filename.is_file() assert mock_get.request_counter == 1 @@ -195,14 +195,18 @@ def test_setup_model(monkeypatch): cache_file_path = cache_file_dir / cache_file_name assert not cache_file_path.is_file() - _, result_path = casanovo.setup_model(file_url, None, None, False) + _, result_path = casanovo.setup_model( + file_url, None, None, None, False + ) assert cache_file_path.is_file() assert result_path.resolve() == cache_file_path.resolve() assert mock_get.request_counter == 2 os.remove(result_path) assert not cache_file_path.is_file() - _, result_path = casanovo.setup_model(file_url, None, None, False) + _, result_path = casanovo.setup_model( + file_url, None, None, 
None, False + ) assert cache_file_path.is_file() assert result_path.resolve() == cache_file_path.resolve() assert mock_get.request_counter == 3 @@ -217,11 +221,15 @@ def test_setup_model(monkeypatch): mnk.setattr(requests, "get", mock_get) temp_file_path = temp_file.name - _, result = casanovo.setup_model(temp_file_path, None, None, False) + _, result = casanovo.setup_model( + temp_file_path, None, None, None, False + ) assert mock_get.request_counter == 3 assert result == temp_file_path - _, result = casanovo.setup_model(temp_file_path, None, None, True) + _, result = casanovo.setup_model( + temp_file_path, None, None, None, True + ) assert mock_get.request_counter == 3 assert result == temp_file_path @@ -233,12 +241,12 @@ def test_setup_model(monkeypatch): mnk.setattr(requests, "get", mock_get) with pytest.raises(ValueError): - casanovo.setup_model("FooBar", None, None, False) + casanovo.setup_model("FooBar", None, None, None, False) assert mock_get.request_counter == 3 with pytest.raises(ValueError): - casanovo.setup_model("FooBar", None, None, False) + casanovo.setup_model("FooBar", None, None, None, False) assert mock_get.request_counter == 3 @@ -929,6 +937,10 @@ def test_check_dir(tmp_path): dne_pattern = "dne-*.ckpt" with pytest.raises(FileExistsError): - utils.check_dir(tmp_path, [exists_pattern, dne_pattern]) + utils.check_dir_file_exists(tmp_path, [exists_pattern, dne_pattern]) + + with pytest.raises(FileExistsError): + utils.check_dir_file_exists(tmp_path, exists_pattern) - utils.check_dir(tmp_path, [dne_pattern]) + utils.check_dir_file_exists(tmp_path, [dne_pattern]) + utils.check_dir_file_exists(tmp_path, dne_pattern)