requested changes, output setup refactor
Lilferrit committed Sep 3, 2024
1 parent 3d91f81 commit 2d6dd00
Showing 4 changed files with 114 additions and 90 deletions.
111 changes: 54 additions & 57 deletions casanovo/casanovo.py
@@ -12,7 +12,7 @@
import urllib.parse
import warnings
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, List

warnings.formatwarning = lambda message, category, *args, **kwargs: (
f"{category.__name__}: {message}"
@@ -168,24 +168,25 @@ def sequence(
to sequence peptides. If evaluate is set to True, PEAK_PATH must be
one or more annotated MGF files.
"""
file_patterns = list()
if output_root is not None and not force_overwrite:
file_patterns = [f"{output_root}.log", f"{output_root}.mztab"]

output, output_dir = _resolve_output(
output_dir, output_root, file_patterns, verbosity
output_path, output_root = _setup_output(
output_dir, output_root, force_overwrite, verbosity
)
config, model = setup_model(model, config, output, False)
utils.check_dir_file_exists(output_path, f"{output_root}.mztab")
config, model = setup_model(model, config, output_dir, output_root, False)
start_time = time.time()
with ModelRunner(config, model, output_root, output_dir, False) as runner:
with ModelRunner(config, model, output_path, output_root, False) as runner:
logger.info(
"Sequencing %speptides from:",
"and evaluating " if evaluate else "",
)
for peak_file in peak_path:
logger.info(" %s", peak_file)

runner.predict(peak_path, output, evaluate=evaluate)
runner.predict(
peak_path,
str((output_path / output_root).with_suffix(".mztab")),
evaluate=evaluate,
)
psms = runner.writer.psms
utils.log_sequencing_report(
psms, start_time=start_time, end_time=time.time()
@@ -225,17 +226,13 @@ def train(
TRAIN_PEAK_PATH must be one or more annotated MGF files, such as those
provided by MassIVE-KB, from which to train a new Casanovo model.
"""
file_patterns = list()
if output_root is not None and not force_overwrite:
file_patterns = [f"{output_root}.log"]

output, output_dir = _resolve_output(
output_dir, output_root, file_patterns, verbosity
output_path, output_root = _setup_output(
output_dir, output_root, force_overwrite, verbosity
)
config, model = setup_model(model, config, output, True)
config, model = setup_model(model, config, output_path, output_root, True)
start_time = time.time()
with ModelRunner(
config, model, output_root, output_dir, not force_overwrite
config, model, output_path, output_root, not force_overwrite
) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
@@ -283,7 +280,7 @@ def configure(output: str) -> None:


def setup_logging(
output: Optional[str],
log_file_path: Path,
verbosity: str,
) -> Path:
"""Set up the logger.
@@ -292,21 +289,11 @@
Parameters
----------
output : Optional[str]
The provided output file name.
log_file_path : Path
The log file path.
verbosity : str
The logging level to use in the console.
Return
------
output : Path
The output file path.
"""
if output is None:
output = f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"

output = Path(output).expanduser().resolve()

logging_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
@@ -333,9 +320,7 @@
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
warnings_logger.addHandler(console_handler)
file_handler = logging.FileHandler(
output.with_suffix(".log"), encoding="utf8"
)
file_handler = logging.FileHandler(log_file_path, encoding="utf8")
file_handler.setFormatter(log_formatter)
root_logger.addHandler(file_handler)
warnings_logger.addHandler(file_handler)
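Since `setup_logging` no longer derives or returns the output path, the caller passes the fully resolved log file path. A minimal sketch of the new call pattern, with a hypothetical output directory and root name (in Casanovo these are produced by `_setup_output`):

```python
from pathlib import Path

# Hypothetical values; _setup_output normally resolves these from the CLI options.
output_path = Path("results").expanduser().resolve()
output_root = "casanovo_20240903120000"

# The caller now hands setup_logging the exact log file path and the console verbosity.
setup_logging((output_path / output_root).with_suffix(".log"), "info")
```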
@@ -352,13 +337,12 @@ def setup_logging(
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

return output


def setup_model(
model: Optional[str],
config: Optional[str],
output: Optional[Path],
output_dir: Optional[Path | str],
output_root_name: Optional[str],
is_train: bool,
) -> Config:
"""Setup Casanovo for most commands.
@@ -418,7 +402,8 @@
logger.info("Casanovo version %s", str(__version__))
logger.debug("model = %s", model)
logger.debug("config = %s", config.file)
logger.debug("output = %s", output)
logger.debug("output directory = %s", output_dir)
logger.debug("output root name = %s", output_root_name)
for key, value in config.items():
logger.debug("%s = %s", str(key), str(value))

@@ -522,42 +507,54 @@ def _get_model_weights(cache_dir: Path) -> str:
)


def _resolve_output(
def _setup_output(
output_dir: str | None,
output_root: str | None,
file_patterns: list[str],
overwrite: bool,
verbosity: str,
) -> Tuple[Path, str]:
"""
Resolves the output directory and sets up logging.
Set up the output directory, output file root name, and logging.
Parameters:
-----------
output_dir : str | None
The path to the output directory. If `None`, the current working
directory will be used.
The path to the output directory. If `None`, the output directory will
be resolved to the current working directory.
output_root : str | None
The base name for the output files. If `None`, no specific base name is
set, and logging will be configured according to the behavior of
`setup_logging`.
file_patterns : list[str]
A list of file patterns that should be checked within the `output_dir`.
The base name for the output files. If `None`, the output root name will
be resolved to `casanovo_<current date and time>`.
overwrite : bool
Whether to overwrite the log file if it already exists in the output
directory.
verbosity : str
The verbosity level for logging.
Returns:
--------
Tuple[Path, str]
The output directory and the base name for log and results files (if
applicable).
A tuple containing the resolved output directory and root name for
output files.
"""
output_dir = Path(output_dir) if output_dir is not None else Path.cwd()
output_base_name = (
None if output_root is None else (output_dir / output_root)
)
utils.check_dir(output_dir, file_patterns)
output = setup_logging(output_base_name, verbosity)
return output, output_dir
if output_root is None:
output_root = (
f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
)

if output_dir is None:
output_path = Path.cwd()
else:
output_path = Path(output_dir)
if not output_path.is_dir():
raise FileNotFoundError(
f"Target output directory {output_dir} does not exists."
)

if not overwrite:
utils.check_dir_file_exists(output_path, f"{output_root}.log")

setup_logging((output_path / output_root).with_suffix(".log"), verbosity)
return output_path, output_root
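To make the refactored setup flow concrete, here is a hedged sketch of what `_setup_output` resolves; the values and assertions are illustrative only:

```python
from pathlib import Path

# Both arguments omitted: fall back to the current working directory and a
# timestamped root name such as "casanovo_20240903120000".
output_path, output_root = _setup_output(None, None, False, "info")
assert output_path == Path.cwd()
assert output_root.startswith("casanovo_")

# With an explicit output_dir, the directory must already exist (otherwise a
# FileNotFoundError is raised), and with overwrite=False an existing
# "<output_root>.log" in that directory raises FileExistsError before any
# logging handlers are attached.
```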


def _get_weights_from_url(
43 changes: 27 additions & 16 deletions casanovo/denovo/model_runner.py
@@ -40,17 +40,21 @@ class ModelRunner:
model_filename : str, optional
The model filename is required for eval and de novo modes,
but not for training a model from scratch.
output_rootname : str, optional
The rootname for all output files (e.g. checkpoints or results)
output_dir : Path | None, optional
The directory where checkpoint files will be saved. If `None`, no
checkpoint files will be saved and a warning will be triggered.
output_rootname : str | None, optional
The root name for checkpoint files. If `None`, no root name will be
used for checkpoint file names.
"""

def __init__(
self,
config: Config,
model_filename: Optional[str] = None,
output_rootname: Optional[str] = None,
output_dir: Optional[str] = None,
overwrite_ckpt_check: bool = True,
output_dir: Optional[Path] = None,
output_rootname: Optional[str] = None,
overwrite_ckpt_check: bool = True,
) -> None:
"""Initialize a ModelRunner"""
self.config = config
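A sketch of how the reordered constructor might be called; `config` is assumed to be an existing `Config` instance and the file names are illustrative:

```python
from pathlib import Path

# With an output directory, checkpoints are written as
# "<output_rootname>.{epoch}-{step}.ckpt" and "<output_rootname>.best.ckpt";
# when overwrite_ckpt_check is True, a FileExistsError is raised if matching
# checkpoint files already exist in output_dir.
runner = ModelRunner(
    config,
    model_filename="casanovo_weights.ckpt",  # hypothetical weights file
    output_dir=Path("results"),
    output_rootname="run1",
    overwrite_ckpt_check=True,
)

# With output_dir=None, no checkpoint callbacks are registered and a warning
# is emitted instead.
inference_runner = ModelRunner(config, "casanovo_weights.ckpt", output_dir=None)
```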
@@ -63,15 +67,19 @@ def __init__(
self.loaders = None
self.writer = None

output_dir = Path.cwd() if output_dir is None else output_dir
prefix = f"{output_rootname}." if output_rootname is not None else ""
curr_filename, best_filename = (
prefix + "{epoch}-{step}",
prefix + "best",
)
if output_dir is None:
self.callbacks = []
warnings.warn(
"Checkpoint directory not set in ModelRunner, "
"no checkpoint files will be saved."
)
return

prefix = f"{output_rootname}." if output_rootname is not None else ""
curr_filename = prefix + "{epoch}-{step}"
best_filename = prefix + "best"
if overwrite_ckpt_check:
utils.check_dir(
utils.check_dir_file_exists(
output_dir,
[
f"{curr_filename.format(epoch='*', step='*')}.ckpt",
@@ -167,7 +175,10 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None:
logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision)

def predict(
self, peak_path: Iterable[str], output: str, evaluate: bool = False
self,
peak_path: Iterable[str],
results_path: str,
evaluate: bool = False,
) -> None:
"""Predict peptide sequences with a trained Casanovo model.
@@ -178,8 +189,8 @@ def predict(
----------
peak_path : iterable of str
The paths to the MS data files for predicting peptide sequences.
output : str
Where should the output be saved?
results_path : str
The sequencing results file path.
evaluate: bool
Whether to run model evaluation in addition to inference.
Note: peak_path must point to annotated MS data files when
@@ -190,7 +201,7 @@
-------
self
"""
self.writer = ms_io.MztabWriter(Path(output).with_suffix(".mztab"))
self.writer = ms_io.MztabWriter(results_path)
self.writer.set_metadata(
self.config,
model=str(self.model_filename),
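A hedged sketch of how the `sequence` command now drives `predict`, with the caller composing the full results path up front; the paths are illustrative and `config` and `model` are assumed to come from `setup_model`:

```python
from pathlib import Path

output_path, output_root = Path("results"), "run1"  # hypothetical values
results_file = str((output_path / output_root).with_suffix(".mztab"))

with ModelRunner(config, model, output_path, output_root, False) as runner:
    # predict no longer appends an .mztab suffix itself; it writes the results
    # to exactly the path it is given.
    runner.predict(["sample.mgf"], results_file, evaluate=False)
```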
18 changes: 11 additions & 7 deletions casanovo/utils.py
@@ -256,26 +256,30 @@ def log_sequencing_report(
)


def check_dir(dir: pathlib.Path, file_patterns: Iterable[str]) -> None:
def check_dir_file_exists(
dir: pathlib.Path, file_patterns: Iterable[str] | str
) -> None:
"""
Check that no file names in dir match any of file_patterns
Parameters
----------
dir : pathlib.Path
The directory to check for matching file names
file_patterns : Iterable[str]
UNIX style wildcard pattern to test file names against
file_patterns : Iterable[str] | str
UNIX style wildcard pattern(s) to test file names against
Raises
------
FileExistsError
If matching file name is found in dir
"""
if isinstance(file_patterns, str):
file_patterns = [file_patterns]

for pattern in file_patterns:
matches = list(dir.glob(pattern))
if len(matches) > 0:
if next(dir.glob(pattern), None) is not None:
raise FileExistsError(
f"File {matches[0].name} already exists in {dir} "
"and can not be overwritten."
f"File matching wildcard pattern {pattern} already exist in"
f"{dir} and can not be overwritten."
)
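A short usage sketch of the renamed helper; the directory and patterns here are hypothetical. Note that a bare pattern string is now accepted in addition to an iterable of patterns:

```python
from pathlib import Path

out_dir = Path("results")  # hypothetical output directory

# Raises FileExistsError if any existing file in out_dir matches a pattern.
check_dir_file_exists(out_dir, ["run1.log", "run1.mztab"])

# A single string pattern is wrapped in a list internally.
check_dir_file_exists(out_dir, "run1.*.ckpt")
```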