diff --git a/.github/workflows/black.yml b/.github/workflows/lint.yml similarity index 83% rename from .github/workflows/black.yml rename to .github/workflows/lint.yml index cec52e2b..fb937494 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/lint.yml @@ -1,6 +1,14 @@ name: Lint -on: [push, pull_request] +on: + push: + branches: + - main + - dev + pull_request: + branches: + - main + - dev jobs: lint: diff --git a/.github/workflows/screenshots.yml b/.github/workflows/screenshots.yml new file mode 100644 index 00000000..a9bcf896 --- /dev/null +++ b/.github/workflows/screenshots.yml @@ -0,0 +1,33 @@ +name: Screenshots with rich-codex +on: + pull_request: + paths: + - "docs/*.md" + - "casanovo/casanovo.py" + workflow_dispatch: + +jobs: + rich_codex: + runs-on: ubuntu-latest + steps: + - name: Check out the repo + uses: actions/checkout@v4 + with: + ref: ${{ github.head_ref }} + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install your custom tools + run: | + python -m pip install --upgrade pip + pip install . + + - name: Generate terminal images with rich-codex + uses: ewels/rich-codex@v1 + with: + timeout: 10 + commit_changes: "true" + clean_img_paths: docs/images/*.svg diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d7fe2f7..08001ed5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,16 +5,20 @@ name: tests on: push: - branches: [ main ] + branches: + - main + - dev pull_request: - branches: [ main ] + branches: + - main + - dev jobs: build: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest] + os: [ubuntu-latest, windows-latest, macos-latest] steps: - uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index 32202470..aa8178a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Test stuff: test_path/ +lightning_logs/ +envs/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/CHANGELOG.md b/CHANGELOG.md index a3258e1c..bbc9284e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,42 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased] +## [4.0.0] - 2023-12-22 + +### Added + +- Checkpoints include model parameters, allowing for mismatches with the provided configuration file. +- `accelerator` parameter controls the accelerator (CPU, GPU, etc.) that is used. +- `devices` parameter controls the number of accelerators used. +- `val_check_interval` parameter controls the frequency of both validation epochs and model checkpointing during training. +- `train_label_smoothing` parameter controls the amount of label smoothing applied when calculating the training loss. + +### Changed + +- The CLI has been overhauled to use subcommands. +- Upgraded to Lightning >=2.0. +- Checkpointing is configured to save the top-k models instead of all. +- Log steps rather than epochs as units of progress during training. +- Validation performance metrics are logged (and added to TensorBoard) at the validation epoch, and training loss is logged at the end of the training epoch, i.e. training and validation metrics are logged asynchronously. +- Irrelevant warning messages on the console output and in the log file are no longer shown. +- Nicely format logged warnings. +- `every_n_train_steps` has been renamed to `val_check_interval` in accordance with the corresponding PyTorch Lightning parameter. +- Training batches are randomly shuffled. +- Upgraded to Torch >=2.1.
+ +### Removed + +- Remove config option for a custom PyTorch Lightning logger. +- Remove superfluous `custom_encoder` config option. + +### Fixed + +- Casanovo runs on CPU and can pass all tests. +- Correctly refer to input peak files by their full file path. +- Specifying custom residues to retrain Casanovo is now possible. +- Upgrade to depthcharge v0.2.3 to fix sinusoidal encoding and for the `PeptideTransformerDecoder` hotfix. +- Correctly report amino acid precision and recall during validation. + ## [3.5.0] - 2023-08-16 ### Fixed @@ -181,7 +217,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Initial Casanovo version. -[Unreleased]: https://github.com/Noble-Lab/casanovo/compare/v3.5.0...HEAD +[Unreleased]: https://github.com/Noble-Lab/casanovo/compare/v4.0.0...HEAD +[4.0.0]: https://github.com/Noble-Lab/casanovo/compare/v3.5.0...v4.0.0 [3.5.0]: https://github.com/Noble-Lab/casanovo/compare/v3.4.0...v3.5.0 [3.4.0]: https://github.com/Noble-Lab/casanovo/compare/v3.3.0...v3.4.0 [3.3.0]: https://github.com/Noble-Lab/casanovo/compare/v3.2.0...v3.3.0 diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 72b02057..0a1c3618 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -7,116 +7,312 @@ import shutil import sys import warnings +from pathlib import Path from typing import Optional, Tuple +warnings.formatwarning = lambda message, category, *args, **kwargs: ( + f"{category.__name__}: {message}" +) warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings( + "ignore", + ".*Consider increasing the value of the `num_workers` argument*", +) +warnings.filterwarnings( + "ignore", + ".*The PyTorch API of nested tensors is in prototype stage*", +) +warnings.filterwarnings( + "ignore", + ".*Converting mask without torch.bool dtype to bool*", +) import appdirs -import click +import depthcharge import github +import lightning import requests +import rich_click as click import torch import tqdm -import yaml -from pytorch_lightning.lite import LightningLite +from lightning.pytorch import seed_everything from . import __version__ from . import utils -from .data import ms_io -from .denovo import model_runner +from .denovo import ModelRunner from .config import Config logger = logging.getLogger("casanovo") +click.rich_click.USE_MARKDOWN = True +click.rich_click.STYLE_HELPTEXT = "" +click.rich_click.SHOW_ARGUMENTS = True -@click.command() -@click.option( - "--mode", +class _SharedParams(click.RichCommand): + """Options shared between most Casanovo commands.""" + + def __init__(self, *args, **kwargs) -> None: + """Define shared options.""" + super().__init__(*args, **kwargs) + self.params += [ + click.Option( + ("-m", "--model"), + help=""" + The model weights (.ckpt file). If not provided, Casanovo + will try to download the latest release. + """, + type=click.Path(exists=True, dir_okay=False), + ), + click.Option( + ("-o", "--output"), + help="The mzTab file to which results will be written.", + type=click.Path(dir_okay=False), + ), + click.Option( + ("-c", "--config"), + help=""" + The YAML configuration file overriding the default options. + """, + type=click.Path(exists=True, dir_okay=False), + ), + click.Option( + ("-v", "--verbosity"), + help=""" + Set the verbosity of console logging messages. Log files are + always set to 'debug'.
+ """, + type=click.Choice( + ["debug", "info", "warning", "error"], + case_sensitive=False, + ), + default="info", + ), + ] + + +@click.group(context_settings=dict(help_option_names=["-h", "--help"])) +def main() -> None: + """# Casanovo + + Casanovo de novo sequences peptides from tandem mass spectra using a + Transformer model. Casanovo currently supports mzML, mzXML, and MGF files + for de novo sequencing and annotated MGF files, such as those from + MassIVE-KB, for training new models. + + Links: + - Documentation: [https://casanovo.readthedocs.io]() + - Official code repository: [https://github.com/Noble-Lab/casanovo]() + + If you use Casanovo in your work, please cite: + - Yilmaz, M., Fondrie, W. E., Bittremieux, W., Oh, S. & Noble, W. S. De novo + mass spectrometry peptide sequencing with a transformer model. Proceedings + of the 39th International Conference on Machine Learning - ICML '22 (2022) + doi:10.1101/2022.02.07.479481. + + """ + return + + +@main.command(cls=_SharedParams) +@click.argument( + "peak_path", required=True, - default="denovo", - help="\b\nThe mode in which to run Casanovo:\n" - '- "denovo" will predict peptide sequences for\nunknown MS/MS spectra.\n' - '- "train" will train a model (from scratch or by\ncontinuing training a ' - "previously trained model).\n" - '- "eval" will evaluate the performance of a\ntrained model using ' - "previously acquired spectrum\nannotations.", - type=click.Choice(["denovo", "train", "eval"]), -) -@click.option( - "--model", - help="The file name of the model weights (.ckpt file).", + nargs=-1, type=click.Path(exists=True, dir_okay=False), ) -@click.option( - "--peak_path", +def sequence( + peak_path: Tuple[str], + model: Optional[str], + config: Optional[str], + output: Optional[str], + verbosity: str, +) -> None: + """De novo sequence peptides from tandem mass spectra. + + PEAK_PATH must be one or more mzML, mzXML, or MGF files from which + to sequence peptides. + """ + output = setup_logging(output, verbosity) + config, model = setup_model(model, config, output, False) + with ModelRunner(config, model) as runner: + logger.info("Sequencing peptides from:") + for peak_file in peak_path: + logger.info(" %s", peak_file) + + runner.predict(peak_path, output) + + logger.info("DONE!") + + +@main.command(cls=_SharedParams) +@click.argument( + "annotated_peak_path", required=True, - help="The file path with peak files for predicting peptide sequences or " - "training Casanovo.", + nargs=-1, + type=click.Path(exists=True, dir_okay=False), ) -@click.option( - "--peak_path_val", - help="The file path with peak files to be used as validation data during " - "training.", +def evaluate( + annotated_peak_path: Tuple[str], + model: Optional[str], + config: Optional[str], + output: Optional[str], + verbosity: str, +) -> None: + """Evaluate de novo peptide sequencing performance. + + ANNOTATED_PEAK_PATH must be one or more annotated MGF files, + such as those provided by MassIVE-KB.
+ """ + output = setup_logging(output, verbosity) + config, model = setup_model(model, config, output, False) + with ModelRunner(config, model) as runner: + logger.info("Sequencing and evaluating peptides from:") + for peak_file in annotated_peak_path: + logger.info(" %s", peak_file) + + runner.evaluate(annotated_peak_path) + + logger.info("DONE!") + + +@main.command(cls=_SharedParams) +@click.argument( + "train_peak_path", + required=True, + nargs=-1, + type=click.Path(exists=True, dir_okay=False), ) @click.option( - "--config", - help="The file name of the configuration file with custom options. If not " - "specified, a default configuration will be used.", + "-p", + "--validation_peak_path", + help=""" + An annotated MGF file for validation, like from MassIVE-KB. Use this + option multiple times to specify multiple files. + """, + required=True, + multiple=True, type=click.Path(exists=True, dir_okay=False), ) +def train( + train_peak_path: Tuple[str], + validation_peak_path: Tuple[str], + model: Optional[str], + config: Optional[str], + output: Optional[str], + verbosity: str, +) -> None: + """Train a Casanovo model on your own data. + + TRAIN_PEAK_PATH must be one or more annotated MGF files, such as those + provided by MassIVE-KB, from which to train a new Casanovo model. + """ + output = setup_logging(output, verbosity) + config, model = setup_model(model, config, output, True) + with ModelRunner(config, model) as runner: + logger.info("Training a model from:") + for peak_file in train_peak_path: + logger.info(" %s", peak_file) + + logger.info("Using the following validation files:") + for peak_file in validation_peak_path: + logger.info(" %s", peak_file) + + runner.train(train_peak_path, validation_peak_path) + + logger.info("DONE!") + + +@main.command() +def version() -> None: + """Get the Casanovo version information.""" + versions = [ + f"Casanovo: {__version__}", + f"Depthcharge: {depthcharge.__version__}", + f"Lightning: {lightning.__version__}", + f"PyTorch: {torch.__version__}", + ] + sys.stdout.write("\n".join(versions) + "\n") + + +@main.command() @click.option( + "-o", "--output", - help="The base output file name to store logging (extension: .log) and " - "(optionally) prediction results (extension: .mztab).", + help="The output configuration file.", + default="casanovo.yaml", type=click.Path(dir_okay=False), ) -def main( - mode: str, - model: Optional[str], - peak_path: str, - peak_path_val: Optional[str], - config: Optional[str], - output: Optional[str], -): +def configure(output: str) -> None: + """Generate a Casanovo configuration file to customize. + + The Casanovo configuration file is in the YAML format. """ - \b - Casanovo: De novo mass spectrometry peptide sequencing with a transformer model. - ================================================================================ + Config.copy_default(output) + output = setup_logging(output, "info") + logger.info(f"Wrote {output}\n") - Yilmaz, M., Fondrie, W. E., Bittremieux, W., Oh, S. & Noble, W. S. De novo - mass spectrometry peptide sequencing with a transformer model. Proceedings - of the 39th International Conference on Machine Learning - ICML '22 (2022) - doi:10.1101/2022.02.07.479481. - Official code website: https://github.com/Noble-Lab/casanovo +def setup_logging( + output: Optional[str], + verbosity: str, +) -> Path: + """Set up the logger. + + Logging occurs to the command-line and to the given log file. + + Parameters + ---------- + output : Optional[str] + The provided output file name.
+ verbosity : str + The logging level to use in the console. + + Returns + ------- + output : Path + The output file path. """ if output is None: - output = os.path.join( - os.getcwd(), - f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}", - ) - else: - basename, ext = os.path.splitext(os.path.abspath(output)) - output = basename if ext.lower() in (".log", ".mztab") else output + output = f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + + output = Path(output).expanduser().resolve() + + logging_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + } # Configure logging. logging.captureWarnings(True) - root = logging.getLogger() - root.setLevel(logging.DEBUG) + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + warnings_logger = logging.getLogger("py.warnings") + + # Formatters for file vs console: + console_formatter = logging.Formatter("{levelname}: {message}", style="{") log_formatter = logging.Formatter( "{asctime} {levelname} [{name}/{processName}] {module}.{funcName} : " "{message}", style="{", ) + console_handler = logging.StreamHandler(sys.stderr) - console_handler.setLevel(logging.DEBUG) - console_handler.setFormatter(log_formatter) - root.addHandler(console_handler) - file_handler = logging.FileHandler(f"{output}.log") + console_handler.setLevel(logging_levels[verbosity.lower()]) + console_handler.setFormatter(console_formatter) + root_logger.addHandler(console_handler) + warnings_logger.addHandler(console_handler) + file_handler = logging.FileHandler(output.with_suffix(".log")) file_handler.setFormatter(log_formatter) - root.addHandler(file_handler) + root_logger.addHandler(file_handler) + warnings_logger.addHandler(file_handler) + # Disable dependency non-critical log messages. - logging.getLogger("depthcharge").setLevel(logging.INFO) + logging.getLogger("depthcharge").setLevel( + logging_levels[verbosity.lower()] + ) + logging.getLogger("fsspec").setLevel(logging.WARNING) logging.getLogger("github").setLevel(logging.WARNING) logging.getLogger("h5py").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING) @@ -124,13 +320,40 @@ def main( logging.getLogger("torch").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) + return output + + +def setup_model( + model: Optional[str], + config: Optional[str], + output: Optional[Path], + is_train: bool, +) -> Tuple[Config, Optional[str]]: + """Set up Casanovo for most commands. + + Parameters + ---------- + model : Optional[str] + The provided model weights file. + config : Optional[str] + The provided configuration file. + output : Optional[Path] + The provided output file name. + is_train : bool + Are we training? If not, we need to retrieve weights when the model is + None. + + Returns + ------- + config : Config + The parsed configuration. + model : Optional[str] + The model weights file to use. + """ # Read parameters from the config file. config = Config(config) - - LightningLite.seed_everything(seed=config["random_seed"], workers=True) + seed_everything(seed=config["random_seed"], workers=True) # Download model weights if these were not specified (except when training). - if model is None and mode != "train": + if model is None and not is_train: try: model = _get_model_weights() except github.RateLimitExceededException: @@ -149,28 +372,13 @@ def main( # Log the active configuration.
logger.info("Casanovo version %s", str(__version__)) - logger.debug("mode = %s", mode) logger.debug("model = %s", model) - logger.debug("peak_path = %s", peak_path) - logger.debug("peak_path_val = %s", peak_path_val) logger.debug("config = %s", config.file) logger.debug("output = %s", output) for key, value in config.items(): logger.debug("%s = %s", str(key), str(value)) - # Run Casanovo in the specified mode. - if mode == "denovo": - logger.info("Predict peptide sequences with Casanovo.") - writer = ms_io.MztabWriter(f"{output}.mztab") - writer.set_metadata(config, model=model, config_filename=config.file) - model_runner.predict(peak_path, model, config, writer) - writer.save() - elif mode == "eval": - logger.info("Evaluate a trained Casanovo model.") - model_runner.evaluate(peak_path, model, config) - elif mode == "train": - logger.info("Train the Casanovo model.") - model_runner.train(peak_path, peak_path_val, model, config) + return config, model def _get_model_weights() -> str: diff --git a/casanovo/config.py b/casanovo/config.py index 0dfdaf67..0b5a1e4d 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -1,10 +1,10 @@ """Parse the YAML configuration.""" import logging +import shutil from pathlib import Path from typing import Optional, Dict, Callable, Tuple, Union import yaml -import torch from . import utils @@ -53,6 +53,7 @@ class Config: residues=dict, n_log=int, tb_summarywriter=str, + train_label_smoothing=float, warmup_iters=int, max_iters=int, learning_rate=float, @@ -64,11 +65,12 @@ class Config: max_epochs=int, num_sanity_val_steps=int, train_from_scratch=bool, - save_model=bool, + save_top_k=int, model_save_folder_path=str, - save_weights_only=bool, - every_n_train_steps=int, - no_gpu=bool, + val_check_interval=int, + calculate_precision=bool, + accelerator=str, + devices=int, ) def __init__(self, config_file: Optional[str] = None): @@ -82,18 +84,25 @@ def __init__(self, config_file: Optional[str] = None): else: with Path(config_file).open() as f_in: self._user_config = yaml.safe_load(f_in) - + # Check for missing entries in config file. + config_missing = self._params.keys() - self._user_config.keys() + if len(config_missing) > 0: + raise KeyError( + "Missing expected config option(s): " + f"{', '.join(config_missing)}" + ) + # Check for unrecognized config file entries. + config_unknown = self._user_config.keys() - self._params.keys() + if len(config_unknown) > 0: + raise KeyError( + "Unrecognized config option(s): " + f"{', '.join(config_unknown)}" + ) # Validate: for key, val in self._config_types.items(): self.validate_param(key, val) - # Add extra configuration options and scale by the number of GPUs. - n_gpus = 0 if self["no_gpu"] else torch.cuda.device_count() self._params["n_workers"] = utils.n_workers() - if n_gpus > 1: - self._params["train_batch_size"] = ( - self["train_batch_size"] // n_gpus - ) def __getitem__(self, param: str) -> Union[int, bool, str, Tuple, Dict]: """Retrieve a parameter""" @@ -133,3 +142,14 @@ def validate_param(self, param: str, param_type: Callable): def items(self) -> Tuple[str, ...]: """Return the parameters""" return self._params.items() + + @classmethod + def copy_default(cls, output: str) -> None: + """Copy the default YAML configuration. + + Parameters + ---------- + output : str + The output file. 
+ """ + shutil.copyfile(cls._default_config, output) diff --git a/casanovo/config.yaml b/casanovo/config.yaml index 0e5c6b95..896f67bc 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -1,16 +1,57 @@ ### # Casanovo configuration. # Blank entries are interpreted as "None". -# Parameters that can be modified when running inference with Casanovo, -# i.e. denovo and eval modes in the command line interface, are marked with -# "(I)". Other parameters shouldn't be changed unless a new Casanovo model -# is being trained. ### -# Random seed to ensure reproducible results. +### +# The following parameters can be modified when running inference or +# when fine-tuning an existing Casanovo model. +### + +# Max absolute difference allowed with respect to observed precursor m/z +# Predictions outside the tolerance range are assigned a negative peptide score. +precursor_mass_tol: 50 # ppm +# Isotopes to consider when comparing predicted and observed precursor m/z's +isotope_error_range: [0, 1] +# The minimum length of predicted peptides +min_peptide_len: 6 +# Number of spectra in one inference batch +predict_batch_size: 1024 +# Number of beams used in beam search +n_beams: 1 +# Number of PSMs for each spectrum +top_match: 1 +# The hardware accelerator to use. Must be one of: +# "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto" +accelerator: "auto" +# The devices to use. Can be set to a positive integer, or to -1 to +# indicate that all available devices should be used. If left empty, the +# appropriate number will be selected automatically based on the chosen +# accelerator. +devices: + +### +# The following parameters should only be modified if you are training a new +# Casanovo model from scratch. +### + +# Random seed to ensure reproducible results random_seed: 454 -# Spectrum processing options. +# OUTPUT OPTIONS +# Logging frequency in training steps +n_log: 1 +# TensorBoard directory to use for keeping track of training metrics +tb_summarywriter: +# Save the top k model checkpoints during training. -1 saves all, and +# leaving this field empty saves none. +save_top_k: 5 +# Path to saved checkpoints +model_save_folder_path: "" +# Model validation and checkpointing frequency in training steps +val_check_interval: 50_000 + +# SPECTRUM PROCESSING OPTIONS # Number of the most intense peaks to retain, any remaining peaks are discarded n_peaks: 150 # Min peak m/z allowed, peaks with smaller m/z are discarded @@ -23,15 +64,8 @@ min_intensity: 0.01 remove_precursor_tol: 2.0 # Da # Max precursor charge allowed, spectra with larger charge are skipped max_charge: 10 -# Max absolute difference allowed with respect to observed precursor m/z (I) -# Predictions outside the tolerance range are assinged a negative peptide score -precursor_mass_tol: 50 # ppm -# Isotopes to consider when comparing predicted and observed precursor m/z's (I) -isotope_error_range: [0, 1] -# The minimum length of predicted peptides (I). -min_peptide_len: 6 -# Model architecture options. +# MODEL ARCHITECTURE OPTIONS # Dimensionality of latent representations, i.e.
peak embeddings dim_model: 512 # Number of attention heads n_head: 8 # Dimensionality of fully connected layers dim_feedforward: 1024 # Number of transformer layers in spectrum encoder and peptide decoder n_layers: 9 # Dropout rate for model weights dropout: 0.0 # Number of dimensions to use for encoding peak intensity # Projected up to ``dim_model`` by default and summed with the peak m/z encoding dim_intensity: -# Option to provide a pre-trained spectrum encoder when training -# Trained from scratch by default -custom_encoder: # Max decoded peptide length max_length: 100 -# Amino acid and modification vocabulary to use +# Number of warmup iterations for learning rate scheduler +warmup_iters: 100_000 +# Max number of iterations for learning rate scheduler +max_iters: 600_000 +# Learning rate for weight updates during training +learning_rate: 5e-4 +# Regularization term for weight updates +weight_decay: 1e-5 +# Amount of label smoothing when computing the training loss +train_label_smoothing: 0.01 + +# TRAINING/INFERENCE OPTIONS +# Number of spectra in one training batch +train_batch_size: 32 +# Max number of training epochs +max_epochs: 30 +# Number of validation steps to run before training begins +num_sanity_val_steps: 0 +# Set to "False" to further train a pre-trained Casanovo model +train_from_scratch: True +# Calculate peptide and amino acid precision during training. This +# is expensive, so we recommend against it. +calculate_precision: False + +# AMINO ACID AND MODIFICATION VOCABULARY residues: "G": 57.021464 "A": 71.037114 @@ -81,43 +136,3 @@ residues: "+43.006": 43.005814 # Carbamylation "-17.027": -17.026549 # NH3 loss "+43.006-17.027": 25.980265 # Carbamylation and NH3 loss -# Logging frequency in training steps -n_log: 1 -# Tensorboard object to keep track of training metrics -tb_summarywriter: -# Number of warmup iterations for learning rate scheduler -warmup_iters: 100_000 -# Max number of iterations for learning rate scheduler -max_iters: 600_000 -# Learning rate for weight updates during training -learning_rate: 5e-4 -# Regularization term for weight updates -weight_decay: 1e-5 - -# Training/inference options. -# Number of spectra in one training batch -train_batch_size: 32 -# Number of spectra in one inference batch (I) -predict_batch_size: 1024 -# Number of beams used in beam search (I) -n_beams: 5 -# Number of PSMs for each spectrum (I) -top_match: 1 -# Object for logging training progress -logger: -# Max number of training epochs -max_epochs: 30 -# Number of validation steps to run before training begins -num_sanity_val_steps: 0 -# Set to "False" to further train a pre-trained Casanovo model -train_from_scratch: True -# Save model checkpoints during training -save_model: True -# Path to saved checkpoints -model_save_folder_path: "" -# Set to "False" to save the PyTorch model instance -save_weights_only: True -# Model validation and checkpointing frequency in training steps -every_n_train_steps: 50_000 -# Disable usage of a GPU (including Apple MPS): -no_gpu: False diff --git a/casanovo/data/ms_io.py b/casanovo/data/ms_io.py index e3b3a8d6..47d99700 100644 --- a/casanovo/data/ms_io.py +++ b/casanovo/data/ms_io.py @@ -5,11 +5,12 @@ import os import re from pathlib import Path -from typing import Any, Dict, List +from typing import List import natsort from .. import __version__ +from ..config import Config class MztabWriter: @@ -42,13 +43,13 @@ def __init__(self, filename: str): self._run_map = {} self.psms = [] - def set_metadata(self, config: Dict[str, Any], **kwargs) -> None: + def set_metadata(self, config: Config, **kwargs) -> None: """ Specify metadata information to write to the mzTab header.
Parameters ---------- - config : Dict[str, Any] + config : Config The active configuration options. kwargs Additional configuration options (i.e. from command-line arguments). diff --git a/casanovo/denovo/__init__.py b/casanovo/denovo/__init__.py index e69de29b..da194f1b 100644 --- a/casanovo/denovo/__init__.py +++ b/casanovo/denovo/__init__.py @@ -0,0 +1 @@ +from .model_runner import ModelRunner diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index 2ee2f8f5..998fa66a 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -3,8 +3,8 @@ import os from typing import List, Optional, Tuple +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch from depthcharge.data import AnnotatedSpectrumIndex @@ -23,8 +23,10 @@ class DeNovoDataModule(pl.LightningDataModule): The spectrum index file corresponding to the validation data. test_index : Optional[AnnotatedSpectrumIndex] The spectrum index file corresponding to the testing data. - batch_size : int - The batch size to use for training and evaluating. + train_batch_size : int + The batch size to use for training. + eval_batch_size : int + The batch size to use for inference. n_peaks : Optional[int] The number of top-n most intense peaks to keep in each spectrum. `None` retains all peaks. @@ -52,7 +54,8 @@ def __init__( train_index: Optional[AnnotatedSpectrumIndex] = None, valid_index: Optional[AnnotatedSpectrumIndex] = None, test_index: Optional[AnnotatedSpectrumIndex] = None, - batch_size: int = 128, + train_batch_size: int = 128, + eval_batch_size: int = 1028, n_peaks: Optional[int] = 150, min_mz: float = 50.0, max_mz: float = 2500.0, @@ -65,7 +68,8 @@ def __init__( self.train_index = train_index self.valid_index = valid_index self.test_index = test_index - self.batch_size = batch_size + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size self.n_peaks = n_peaks self.min_mz = min_mz self.max_mz = max_mz @@ -119,7 +123,10 @@ def setup(self, stage: str = None, annotated: bool = True) -> None: self.test_dataset = make_dataset(self.test_index) def _make_loader( - self, dataset: torch.utils.data.Dataset + self, + dataset: torch.utils.data.Dataset, + batch_size: int, + shuffle: bool = False, ) -> torch.utils.data.DataLoader: """ Create a PyTorch DataLoader. @@ -128,6 +135,10 @@ def _make_loader( ---------- dataset : torch.utils.data.Dataset A PyTorch Dataset. + batch_size : int + The batch size to use. + shuffle : bool + Option to shuffle the batches. 
Returns ------- @@ -136,27 +147,30 @@ def _make_loader( """ return torch.utils.data.DataLoader( dataset, - batch_size=self.batch_size, + batch_size=batch_size, collate_fn=prepare_batch, pin_memory=True, num_workers=self.n_workers, + shuffle=shuffle, ) def train_dataloader(self) -> torch.utils.data.DataLoader: """Get the training DataLoader.""" - return self._make_loader(self.train_dataset) + return self._make_loader( + self.train_dataset, self.train_batch_size, shuffle=True + ) def val_dataloader(self) -> torch.utils.data.DataLoader: """Get the validation DataLoader.""" - return self._make_loader(self.valid_dataset) + return self._make_loader(self.valid_dataset, self.eval_batch_size) def test_dataloader(self) -> torch.utils.data.DataLoader: """Get the test DataLoader.""" - return self._make_loader(self.test_dataset) + return self._make_loader(self.test_dataset, self.eval_batch_size) def predict_dataloader(self) -> torch.utils.data.DataLoader: """Get the predict DataLoader.""" - return self._make_loader(self.test_dataset) + return self._make_loader(self.test_dataset, self.eval_batch_size) def prepare_batch( diff --git a/casanovo/denovo/evaluate.py b/casanovo/denovo/evaluate.py index 25bb9984..75ac4b6a 100644 --- a/casanovo/denovo/evaluate.py +++ b/casanovo/denovo/evaluate.py @@ -278,7 +278,7 @@ def aa_match_metrics( pep_precision = sum([aa_matches[1] for aa_matches in aa_matches_batch]) / ( len(aa_matches_batch) + 1e-8 ) - return aa_precision, aa_recall, pep_precision + return float(aa_precision), float(aa_recall), float(pep_precision) def aa_precision_recall( diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 8b3ae3e0..39d2027a 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -1,15 +1,14 @@ """A de novo peptide sequencing model.""" import collections import heapq -import itertools import logging from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import depthcharge.masses import einops -import numpy as np -import pytorch_lightning as pl import torch +import numpy as np +import lightning.pytorch as pl from torch.utils.tensorboard import SummaryWriter from depthcharge.components import ModelMixin, PeptideDecoder, SpectrumEncoder @@ -44,9 +43,6 @@ class Spec2Pep(pl.LightningModule, ModelMixin): (``dim_model - dim_intensity``) are reserved for encoding the m/z value. If ``None``, the intensity will be projected up to ``dim_model`` using a linear layer, then summed with the m/z encoding for each peak. - custom_encoder : Optional[Union[SpectrumEncoder, PairedSpectrumEncoder]] - A pretrained encoder to use. The ``dim_model`` of the encoder must be - the same as that specified by the ``dim_model`` parameter here. max_length : int The maximum peptide length to decode. residues: Union[Dict[str, float], str] @@ -77,12 +73,17 @@ class Spec2Pep(pl.LightningModule, ModelMixin): tb_summarywriter: Optional[str] Folder path to record performance metrics during training. If ``None``, don't use a ``SummaryWriter``. + train_label_smoothing: float + Smoothing factor when calculating the training loss. warmup_iters: int The number of warm up iterations for the learning rate scheduler. max_iters: int The total number of iterations for the learning rate scheduler. out_writer: Optional[str] The output writer for the prediction results. + calculate_precision: bool + Calculate the validation set precision during training. + This is expensive. **kwargs : Dict Additional keyword arguments passed to the Adam optimizer. 
""" @@ -95,38 +96,37 @@ def __init__( n_layers: int = 9, dropout: float = 0.0, dim_intensity: Optional[int] = None, - custom_encoder: Optional[SpectrumEncoder] = None, max_length: int = 100, residues: Union[Dict[str, float], str] = "canonical", max_charge: int = 5, precursor_mass_tol: float = 50, isotope_error_range: Tuple[int, int] = (0, 1), min_peptide_len: int = 6, - n_beams: int = 5, + n_beams: int = 1, top_match: int = 1, n_log: int = 10, tb_summarywriter: Optional[ torch.utils.tensorboard.SummaryWriter ] = None, + train_label_smoothing: float = 0.01, warmup_iters: int = 100_000, max_iters: int = 600_000, out_writer: Optional[ms_io.MztabWriter] = None, + calculate_precision: bool = False, **kwargs: Dict, ): super().__init__() + self.save_hyperparameters() # Build the model. - if custom_encoder is not None: - self.encoder = custom_encoder - else: - self.encoder = SpectrumEncoder( - dim_model=dim_model, - n_head=n_head, - dim_feedforward=dim_feedforward, - n_layers=n_layers, - dropout=dropout, - dim_intensity=dim_intensity, - ) + self.encoder = SpectrumEncoder( + dim_model=dim_model, + n_head=n_head, + dim_feedforward=dim_feedforward, + n_layers=n_layers, + dropout=dropout, + dim_intensity=dim_intensity, + ) self.decoder = PeptideDecoder( dim_model=dim_model, n_head=n_head, @@ -137,7 +137,10 @@ def __init__( max_charge=max_charge, ) self.softmax = torch.nn.Softmax(2) - self.celoss = torch.nn.CrossEntropyLoss(ignore_index=0) + self.celoss = torch.nn.CrossEntropyLoss( + ignore_index=0, label_smoothing=train_label_smoothing + ) + self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=0) # Optimizer settings. self.warmup_iters = warmup_iters self.max_iters = max_iters @@ -157,6 +160,7 @@ def __init__( self.stop_token = self.decoder._aa2idx["$"] # Logging. + self.calculate_precision = calculate_precision self.n_log = n_log self._history = [] if tb_summarywriter is not None: @@ -604,9 +608,12 @@ def _get_topk_beams( # Mask out terminated beams. Include precursor m/z tolerance induced # termination. + # TODO: `clone()` is necessary to get the correct output with n_beams=1. + # An alternative implementation using base PyTorch instead of einops + # might be more efficient. finished_mask = einops.repeat( finished_beams, "(B S) -> B (V S)", S=beam, V=vocab - ) + ).clone() # Mask out the index '0', i.e. padding token, by default. finished_mask[:, :beam] = True @@ -722,10 +729,13 @@ def training_step( """ pred, truth = self._forward_step(*batch) pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1) - loss = self.celoss(pred, truth.flatten()) + if mode == "train": + loss = self.celoss(pred, truth.flatten()) + else: + loss = self.val_celoss(pred, truth.flatten()) self.log( - "CELoss", - {mode: loss.detach()}, + f"{mode}_CELoss", + loss.detach(), on_step=False, on_epoch=True, sync_dist=True, @@ -751,6 +761,8 @@ def validation_step( """ # Record the loss. loss = self.training_step(batch, mode="valid") + if not self.calculate_precision: + return loss # Calculate and log amino acid and peptide match evaluation metrics from # the predicted peptides. 
@@ -758,21 +770,25 @@ def validation_step( for spectrum_preds in self.forward(batch[0], batch[1]): for _, _, pred in spectrum_preds: peptides_pred.append(pred) + aa_precision, _, pep_precision = evaluate.aa_match_metrics( *evaluate.aa_match_batch( - peptides_pred, peptides_true, self.decoder._peptide_mass.masses + peptides_true, + peptides_pred, + self.decoder._peptide_mass.masses, ) ) log_args = dict(on_step=False, on_epoch=True, sync_dist=True) self.log( "Peptide precision at coverage=1", - {"valid": pep_precision}, + pep_precision, **log_args, ) self.log( - "AA precision at coverage=1", {"valid": aa_precision}, **log_args + "AA precision at coverage=1", + aa_precision, + **log_args, ) - return loss def predict_step( @@ -824,10 +840,10 @@ def on_train_epoch_end(self) -> None: """ Log the training loss at the end of each epoch. """ - train_loss = self.trainer.callback_metrics["CELoss"]["train"].detach() + train_loss = self.trainer.callback_metrics["train_CELoss"].detach() metrics = { "step": self.trainer.global_step, - "train": train_loss, + "train": train_loss.item(), } self._history.append(metrics) self._log_history() @@ -839,19 +855,25 @@ def on_validation_epoch_end(self) -> None: callback_metrics = self.trainer.callback_metrics metrics = { "step": self.trainer.global_step, - "valid": callback_metrics["CELoss"]["valid"].detach(), - "valid_aa_precision": callback_metrics[ - "AA precision at coverage=1" - ]["valid"].detach(), - "valid_pep_precision": callback_metrics[ - "Peptide precision at coverage=1" - ]["valid"].detach(), + "valid": callback_metrics["valid_CELoss"].detach().item(), } + + if self.calculate_precision: + metrics["valid_aa_precision"] = ( + callback_metrics["AA precision at coverage=1"].detach().item() + ) + metrics["valid_pep_precision"] = ( + callback_metrics["Peptide precision at coverage=1"] + .detach() + .item() + ) self._history.append(metrics) self._log_history() - def on_predict_epoch_end( - self, results: List[List[Tuple[np.ndarray, List[str], torch.Tensor]]] + def on_predict_batch_end( + self, + outputs: List[Tuple[np.ndarray, List[str], torch.Tensor]], + *args, ) -> None: """ Write the predicted peptide sequences and amino acid scores to the @@ -867,9 +889,7 @@ def on_predict_epoch_end( peptide, peptide_score, aa_scores, - ) in itertools.chain.from_iterable( - itertools.chain.from_iterable(results) - ): + ) in outputs: if len(peptide) == 0: continue self.out_writer.psms.append( @@ -892,19 +912,28 @@ def _log_history(self) -> None: if len(self._history) == 0: return if len(self._history) == 1: - logger.info( - "Step\tTrain loss\tValid loss\tPeptide precision\tAA precision" - ) + header = "Step\tTrain loss\tValid loss\t" + if self.calculate_precision: + header += "Peptide precision\tAA precision" + + logger.info(header) metrics = self._history[-1] if metrics["step"] % self.n_log == 0: - logger.info( - "%i\t%.6f\t%.6f\t%.6f\t%.6f", + msg = "%i\t%.6f\t%.6f" + vals = [ metrics["step"], metrics.get("train", np.nan), metrics.get("valid", np.nan), - metrics.get("valid_pep_precision", np.nan), - metrics.get("valid_aa_precision", np.nan), - ) + ] + + if self.calculate_precision: + msg += "\t%.6f\t%.6f" + vals += [ + metrics.get("valid_pep_precision", np.nan), + metrics.get("valid_aa_precision", np.nan), + ] + + logger.info(msg, *vals) if self.tb_summarywriter is not None: for descr, key in [ ("loss/train_crossentropy_loss", "train"), diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 223c14c7..c7a9cab6 100644 --- 
a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -2,19 +2,21 @@ model.""" import glob import logging -import operator import os import tempfile import uuid -from typing import Any, Dict, Iterable, List, Optional, Union +import warnings +from pathlib import Path +from typing import Iterable, List, Optional, Union +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex -from pytorch_lightning.strategies import DDPStrategy +from lightning.pytorch.strategies import DDPStrategy +from lightning.pytorch.callbacks import ModelCheckpoint -from .. import utils +from ..config import Config from ..data import ms_io from ..denovo.dataloaders import DeNovoDataModule from ..denovo.model import Spec2Pep @@ -23,307 +25,398 @@ logger = logging.getLogger("casanovo") -def predict( - peak_path: str, - model_filename: str, - config: Dict[str, Any], - out_writer: ms_io.MztabWriter, -) -> None: - """ - Predict peptide sequences with a trained Casanovo model. +class ModelRunner: + """A class to run Casanovo models. Parameters ---------- - peak_path : str - The path with peak files for predicting peptide sequences. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - out_writer : ms_io.MztabWriter - The mzTab writer to export the prediction results. - """ - _execute_existing(peak_path, model_filename, config, False, out_writer) - - -def evaluate( - peak_path: str, model_filename: str, config: Dict[str, Any] -) -> None: + config : Config object + The Casanovo configuration. + model_filename : str, optional + The model filename is required for eval and de novo modes, + but not for training a model from scratch. """ - Evaluate peptide sequence predictions from a trained Casanovo model. - Parameters - ---------- - peak_path : str - The path with peak files for predicting peptide sequences. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - """ - _execute_existing(peak_path, model_filename, config, True) + def __init__( + self, + config: Config, + model_filename: Optional[str] = None, + ) -> None: + """Initialize a ModelRunner.""" + self.config = config + self.model_filename = model_filename + + # Initialized later: + self.tmp_dir = None + self.trainer = None + self.model = None + self.loaders = None + self.writer = None + + # Configure checkpoints. + if config.save_top_k is not None: + self.callbacks = [ + ModelCheckpoint( + dirpath=config.model_save_folder_path, + monitor="valid_CELoss", + mode="min", + save_top_k=config.save_top_k, + ) + ] + else: + self.callbacks = None + + def __enter__(self): + """Enter the context manager.""" + self.tmp_dir = tempfile.TemporaryDirectory() + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Clean up on exit.""" + self.tmp_dir.cleanup() + self.tmp_dir = None + if self.writer is not None: + self.writer.save() + + def train( + self, + train_peak_path: Iterable[str], + valid_peak_path: Iterable[str], + ) -> None: + """Train the Casanovo model. + + Parameters + ---------- + train_peak_path : iterable of str + The path to the MS data files for training. + valid_peak_path : iterable of str + The path to the MS data files for validation.
+ + Returns + ------- + self + """ + self.initialize_trainer(train=True) + self.initialize_model(train=True) + + train_index = self._get_index(train_peak_path, True, "training") + valid_index = self._get_index(valid_peak_path, True, "validation") + self.initialize_data_module(train_index, valid_index) + self.loaders.setup() + + self.trainer.fit( + self.model, + self.loaders.train_dataloader(), + self.loaders.val_dataloader(), + ) + def evaluate(self, peak_path: Iterable[str]) -> None: + """Evaluate peptide sequence predictions from a trained Casanovo model. + + Parameters + ---------- + peak_path : iterable of str + The path with MS data files for predicting peptide sequences. + + Returns + ------- + self + """ + self.initialize_trainer(train=False) + self.initialize_model(train=False) + + test_index = self._get_index(peak_path, True, "evaluation") + self.initialize_data_module(test_index=test_index) + self.loaders.setup(stage="test", annotated=True) + + self.trainer.validate(self.model, self.loaders.test_dataloader()) + + def predict(self, peak_path: Iterable[str], output: str) -> None: + """Predict peptide sequences with a trained Casanovo model. + + Parameters + ---------- + peak_path : iterable of str + The path with the MS data files for predicting peptide sequences. + output : str + Where should the output be saved? + + Returns + ------- + self + """ + self.writer = ms_io.MztabWriter(Path(output).with_suffix(".mztab")) + self.writer.set_metadata( + self.config, + model=str(self.model_filename), + config_filename=self.config.file, + ) -def _execute_existing( - peak_path: str, - model_filename: str, - config: Dict[str, Any], - annotated: bool, - out_writer: Optional[ms_io.MztabWriter] = None, -) -> None: - """ - Predict peptide sequences with a trained Casanovo model with/without - evaluation. + self.initialize_trainer(train=False) + self.initialize_model(train=False) + self.model.out_writer = self.writer + + test_index = self._get_index(peak_path, False, "") + self.writer.set_ms_run(test_index.ms_files) + self.initialize_data_module(test_index=test_index) + self.loaders.setup(stage="test", annotated=False) + self.trainer.predict(self.model, self.loaders.test_dataloader()) + + def initialize_trainer(self, train: bool) -> None: + """Initialize the Lightning Trainer. + + Parameters + ---------- + train : bool + Determines whether to set the trainer up for model training + or evaluation / inference. + """ + trainer_cfg = dict( + accelerator=self.config.accelerator, + devices=1, + enable_checkpointing=False, + ) - Parameters - ---------- - peak_path : str - The path with peak files for predicting peptide sequences. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - annotated : bool - Whether the input peak files are annotated (execute in evaluation mode) - or not (execute in prediction mode only). - out_writer : Optional[ms_io.MztabWriter] - The mzTab writer to export the prediction results. - """ - # Load the trained model.
- if not os.path.isfile(model_filename): - logger.error( - "Could not find the trained model weights at file %s", - model_filename, + if train: + if self.config.devices is None: + devices = "auto" + else: + devices = self.config.devices + + additional_cfg = dict( + devices=devices, + callbacks=self.callbacks, + enable_checkpointing=self.config.save_top_k is not None, + max_epochs=self.config.max_epochs, + num_sanity_val_steps=self.config.num_sanity_val_steps, + strategy=self._get_strategy(), + val_check_interval=self.config.val_check_interval, + check_val_every_n_epoch=None, + ) + trainer_cfg.update(additional_cfg) + + self.trainer = pl.Trainer(**trainer_cfg) + + def initialize_model(self, train: bool) -> None: + """Initialize the Casanovo model. + + Parameters + ---------- + train : bool + Determines whether to set the model up for model training + or evaluation / inference. + """ + model_params = dict( + dim_model=self.config.dim_model, + n_head=self.config.n_head, + dim_feedforward=self.config.dim_feedforward, + n_layers=self.config.n_layers, + dropout=self.config.dropout, + dim_intensity=self.config.dim_intensity, + max_length=self.config.max_length, + residues=self.config.residues, + max_charge=self.config.max_charge, + precursor_mass_tol=self.config.precursor_mass_tol, + isotope_error_range=self.config.isotope_error_range, + min_peptide_len=self.config.min_peptide_len, + n_beams=self.config.n_beams, + top_match=self.config.top_match, + n_log=self.config.n_log, + tb_summarywriter=self.config.tb_summarywriter, + train_label_smoothing=self.config.train_label_smoothing, + warmup_iters=self.config.warmup_iters, + max_iters=self.config.max_iters, + lr=self.config.learning_rate, + weight_decay=self.config.weight_decay, + out_writer=self.writer, + calculate_precision=self.config.calculate_precision, ) - raise FileNotFoundError("Could not find the trained model weights") - model = Spec2Pep().load_from_checkpoint( - model_filename, - dim_model=config["dim_model"], - n_head=config["n_head"], - dim_feedforward=config["dim_feedforward"], - n_layers=config["n_layers"], - dropout=config["dropout"], - dim_intensity=config["dim_intensity"], - custom_encoder=config["custom_encoder"], - max_length=config["max_length"], - residues=config["residues"], - max_charge=config["max_charge"], - precursor_mass_tol=config["precursor_mass_tol"], - isotope_error_range=config["isotope_error_range"], - min_peptide_len=config["min_peptide_len"], - n_beams=config["n_beams"], - top_match=config["top_match"], - n_log=config["n_log"], - out_writer=out_writer, - ) - # Read the MS/MS spectra for which to predict peptide sequences. 
- if annotated: - peak_ext = (".mgf", ".h5", ".hdf5") - else: - peak_ext = (".mgf", ".mzml", ".mzxml", ".h5", ".hdf5") - if len(peak_filenames := _get_peak_filenames(peak_path, peak_ext)) == 0: - logger.error("Could not find peak files from %s", peak_path) - raise FileNotFoundError("Could not find peak files") - elif out_writer is not None: - out_writer.set_ms_run(peak_filenames) - peak_is_index = any( - [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in peak_filenames] - ) - if peak_is_index and len(peak_filenames) > 1: - logger.error("Multiple HDF5 spectrum indexes specified") - raise ValueError("Multiple HDF5 spectrum indexes specified") - tmp_dir = tempfile.TemporaryDirectory() - if peak_is_index: - idx_filename, peak_filenames = peak_filenames[0], None - else: - idx_filename = os.path.join(tmp_dir.name, f"{uuid.uuid4().hex}.hdf5") - SpectrumIdx = AnnotatedSpectrumIndex if annotated else SpectrumIndex - valid_charge = np.arange(1, config["max_charge"] + 1) - index = SpectrumIdx( - idx_filename, peak_filenames, valid_charge=valid_charge - ) - # Initialize the data loader. - loaders = DeNovoDataModule( - test_index=index, - n_peaks=config["n_peaks"], - min_mz=config["min_mz"], - max_mz=config["max_mz"], - min_intensity=config["min_intensity"], - remove_precursor_tol=config["remove_precursor_tol"], - n_workers=config["n_workers"], - batch_size=config["predict_batch_size"], - ) - loaders.setup(stage="test", annotated=annotated) - - # Create the Trainer object. - trainer = pl.Trainer( - accelerator="auto", - auto_select_gpus=True, - devices=_get_devices(config["no_gpu"]), - logger=config["logger"], - max_epochs=config["max_epochs"], - num_sanity_val_steps=config["num_sanity_val_steps"], - strategy=_get_strategy(), - ) - # Run the model with/without validation. - run_trainer = trainer.validate if annotated else trainer.predict - run_trainer(model, loaders.test_dataloader()) - # Clean up temporary files. - tmp_dir.cleanup() - - -def train( - peak_path: str, - peak_path_val: str, - model_filename: str, - config: Dict[str, Any], -) -> None: - """ - Train a Casanovo model. - The model can be trained from scratch or by continuing training an existing - model. + # Reconfigurable non-architecture related parameters for a loaded model + loaded_model_params = dict( + max_length=self.config.max_length, + precursor_mass_tol=self.config.precursor_mass_tol, + isotope_error_range=self.config.isotope_error_range, + n_beams=self.config.n_beams, + min_peptide_len=self.config.min_peptide_len, + top_match=self.config.top_match, + n_log=self.config.n_log, + tb_summarywriter=self.config.tb_summarywriter, + train_label_smoothing=self.config.train_label_smoothing, + warmup_iters=self.config.warmup_iters, + max_iters=self.config.max_iters, + lr=self.config.learning_rate, + weight_decay=self.config.weight_decay, + out_writer=self.writer, + calculate_precision=self.config.calculate_precision, + ) - Parameters - ---------- - peak_path : str - The path with peak files to be used as training data. - peak_path_val : str - The path with peak files to be used as validation data. - model_filename : str - The file name of the model weights (.ckpt file). - config : Dict[str, Any] - The configuration options. - """ - # Read the MS/MS spectra to use for training and validation. 
- ext = (".mgf", ".h5", ".hdf5") - if len(train_filenames := _get_peak_filenames(peak_path, ext)) == 0: - logger.error("Could not find training peak files from %s", peak_path) - raise FileNotFoundError("Could not find training peak files") - train_is_index = any( - [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in train_filenames] - ) - if train_is_index and len(train_filenames) > 1: - logger.error("Multiple training HDF5 spectrum indexes specified") - raise ValueError("Multiple training HDF5 spectrum indexes specified") - if ( - peak_path_val is None - or len(val_filenames := _get_peak_filenames(peak_path_val, ext)) == 0 - ): - logger.error( - "Could not find validation peak files from %s", peak_path_val + from_scratch = ( + self.config.train_from_scratch, + self.model_filename is None, ) - raise FileNotFoundError("Could not find validation peak files") - val_is_index = any( - [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in val_filenames] - ) - if val_is_index and len(val_filenames) > 1: - logger.error("Multiple validation HDF5 spectrum indexes specified") - raise ValueError("Multiple validation HDF5 spectrum indexes specified") - tmp_dir = tempfile.TemporaryDirectory() - if train_is_index: - train_idx_fn, train_filenames = train_filenames[0], None - else: - train_idx_fn = os.path.join(tmp_dir.name, f"{uuid.uuid4().hex}.hdf5") - valid_charge = np.arange(1, config["max_charge"] + 1) - train_index = AnnotatedSpectrumIndex( - train_idx_fn, train_filenames, valid_charge=valid_charge - ) - if val_is_index: - val_idx_fn, val_filenames = val_filenames[0], None - else: - val_idx_fn = os.path.join(tmp_dir.name, f"{uuid.uuid4().hex}.hdf5") - val_index = AnnotatedSpectrumIndex( - val_idx_fn, val_filenames, valid_charge=valid_charge - ) - # Initialize the data loaders. - dataloader_params = dict( - batch_size=config["train_batch_size"], - n_peaks=config["n_peaks"], - min_mz=config["min_mz"], - max_mz=config["max_mz"], - min_intensity=config["min_intensity"], - remove_precursor_tol=config["remove_precursor_tol"], - n_workers=config["n_workers"], - ) - train_loader = DeNovoDataModule( - train_index=train_index, **dataloader_params - ) - train_loader.setup() - val_loader = DeNovoDataModule(valid_index=val_index, **dataloader_params) - val_loader.setup() - # Initialize the model. 
- model_params = dict( - dim_model=config["dim_model"], - n_head=config["n_head"], - dim_feedforward=config["dim_feedforward"], - n_layers=config["n_layers"], - dropout=config["dropout"], - dim_intensity=config["dim_intensity"], - custom_encoder=config["custom_encoder"], - max_length=config["max_length"], - residues=config["residues"], - max_charge=config["max_charge"], - precursor_mass_tol=config["precursor_mass_tol"], - isotope_error_range=config["isotope_error_range"], - n_beams=config["n_beams"], - top_match=config["top_match"], - n_log=config["n_log"], - tb_summarywriter=config["tb_summarywriter"], - warmup_iters=config["warmup_iters"], - max_iters=config["max_iters"], - lr=config["learning_rate"], - weight_decay=config["weight_decay"], - ) - if config["train_from_scratch"]: - model = Spec2Pep(**model_params) - else: - if not os.path.isfile(model_filename): + if train and any(from_scratch): + self.model = Spec2Pep(**model_params) + return + elif self.model_filename is None: + logger.error("A model file must be provided") + raise ValueError("A model file must be provided") + + if not Path(self.model_filename).exists(): logger.error( - "Could not find the model weights at file %s to continue " - "training", - model_filename, + "Could not find the model weights at file %s", + self.model_filename, ) - raise FileNotFoundError( - "Could not find the model weights to continue training" + raise FileNotFoundError("Could not find the model weights file") + + # First try loading model details from the weights file, otherwise use + # the provided configuration. + device = torch.empty(1).device # Use the default device. + try: + self.model = Spec2Pep.load_from_checkpoint( + self.model_filename, map_location=device, **loaded_model_params ) - model = Spec2Pep().load_from_checkpoint(model_filename, **model_params) - # Create the Trainer object and (optionally) a checkpoint callback to - # periodically save the model. - if config["save_model"]: - callbacks = [ - pl.callbacks.ModelCheckpoint( - dirpath=config["model_save_folder_path"], - save_top_k=-1, - save_weights_only=config["save_weights_only"], - every_n_train_steps=config["every_n_train_steps"], + + architecture_params = set(model_params.keys()) - set( + loaded_model_params.keys() ) - ] - else: - callbacks = None - - trainer = pl.Trainer( - accelerator="auto", - auto_select_gpus=True, - callbacks=callbacks, - devices=_get_devices(config["no_gpu"]), - enable_checkpointing=config["save_model"], - logger=config["logger"], - max_epochs=config["max_epochs"], - num_sanity_val_steps=config["num_sanity_val_steps"], - strategy=_get_strategy(), - val_check_interval=config["every_n_train_steps"], - ) - # Train the model. - trainer.fit( - model, train_loader.train_dataloader(), val_loader.val_dataloader() - ) - # Clean up temporary files. - tmp_dir.cleanup() + for param in architecture_params: + if model_params[param] != self.model.hparams[param]: + warnings.warn( + f"Mismatching {param} parameter in " + f"model checkpoint ({self.model.hparams[param]}) " + f"vs config file ({model_params[param]}); " + "using the checkpoint." + ) + except RuntimeError: + # This only fails if the weights are from an older version. + try: + self.model = Spec2Pep.load_from_checkpoint( + self.model_filename, + map_location=device, + **model_params, + ) + except RuntimeError: + raise RuntimeError( + "Weights file incompatible with the current version of " + "Casanovo. 
" + ) + + def initialize_data_module( + self, + train_index: Optional[AnnotatedSpectrumIndex] = None, + valid_index: Optional[AnnotatedSpectrumIndex] = None, + test_index: ( + Optional[Union[AnnotatedSpectrumIndex, SpectrumIndex]] + ) = None, + ) -> None: + """Initialize the data module. + + Parameters + ---------- + train_index : AnnotatedSpectrumIndex, optional + A spectrum index for model training. + valid_index : AnnotatedSpectrumIndex, optional + A spectrum index for validation. + test_index : AnnotatedSpectrumIndex or SpectrumIndex, optional + A spectrum index for evaluation or inference. + """ + try: + n_devices = self.trainer.num_devices + train_bs = self.config.train_batch_size // n_devices + eval_bs = self.config.predict_batch_size // n_devices + except AttributeError: + raise RuntimeError("Please use `initialize_trainer()` first.") + + self.loaders = DeNovoDataModule( + train_index=train_index, + valid_index=valid_index, + test_index=test_index, + min_mz=self.config.min_mz, + max_mz=self.config.max_mz, + min_intensity=self.config.min_intensity, + remove_precursor_tol=self.config.remove_precursor_tol, + n_workers=self.config.n_workers, + train_batch_size=train_bs, + eval_batch_size=eval_bs, + ) + + def _get_index( + self, + peak_path: Iterable[str], + annotated: bool, + msg: str = "", + ) -> Union[SpectrumIndex, AnnotatedSpectrumIndex]: + """Get the spectrum index. + + If the file is a SpectrumIndex, only one is allowed. Otherwise multiple + may be specified. + + Parameters + ---------- + peak_path : Iterable[str] + The peak files/directories to check. + annotated : bool + Are the spectra expected to be annotated? + msg : str, optional + A string to insert into the error message. + + Returns + ------- + SpectrumIndex or AnnotatedSpectrumIndex + The spectrum index for training, evaluation, or inference. + """ + ext = (".mgf", ".h5", ".hdf5") + if not annotated: + ext += (".mzml", ".mzxml") + + msg = msg.strip() + filenames = _get_peak_filenames(peak_path, ext) + if not filenames: + not_found_err = f"Could not find {msg} peak files" + logger.error(not_found_err + " from %s", peak_path) + raise FileNotFoundError(not_found_err) + + is_index = any([Path(f).suffix in (".h5", ".hdf5") for f in filenames]) + if is_index: + if len(filenames) > 1: + h5_err = f"Multiple {msg} HDF5 spectrum indexes specified" + logger.error(h5_err) + raise ValueError(h5_err) + + index_fname, filenames = filenames[0], None + else: + index_fname = Path(self.tmp_dir.name) / f"{uuid.uuid4().hex}.hdf5" + + Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex + valid_charge = np.arange(1, self.config.max_charge + 1) + return Index(index_fname, filenames, valid_charge=valid_charge) + + def _get_strategy(self) -> Union[str, DDPStrategy]: + """Get the strategy for the Trainer. + + The DDP strategy works best when multiple GPUs are used. It can work + for CPU-only, but definitely fails using MPS (the Apple Silicon chip) + due to Gloo. + + Returns + ------- + Union[str, DDPStrategy] + The strategy parameter for the Trainer. + + """ + if self.config.accelerator in ("cpu", "mps"): + return "auto" + elif self.config.devices == 1: + return "auto" + elif torch.cuda.device_count() > 1: + return DDPStrategy(find_unused_parameters=False, static_graph=True) + else: + return "auto" def _get_peak_filenames( - path: str, supported_ext: Iterable[str] = (".mgf",) + paths: Iterable[str], supported_ext: Iterable[str] ) -> List[str]: """ Get all matching peak file names from the path pattern.
@@ -333,65 +426,22 @@ def _get_peak_filenames( Parameters ---------- - path : str - The path pattern. + paths : Iterable[str] + The path pattern(s). supported_ext : Iterable[str] - Extensions of supported peak file formats. Default: MGF. + Extensions of supported peak file formats. Returns ------- List[str] The peak file names matching the path pattern. """ - path = os.path.expanduser(path) - path = os.path.expandvars(path) - return [ - os.path.abspath(fn) - for fn in glob.glob(path, recursive=True) - if os.path.splitext(fn.lower())[1] in supported_ext - ] - - -def _get_strategy() -> Optional[DDPStrategy]: - """ - Get the strategy for the Trainer. - - The DDP strategy works best when multiple GPUs are used. It can work for - CPU-only, but definitely fails using MPS (the Apple Silicon chip) due to - Gloo. - - Returns - ------- - Optional[DDPStrategy] - The strategy parameter for the Trainer. - """ - if torch.cuda.device_count() > 1: - return DDPStrategy(find_unused_parameters=False, static_graph=True) - - return None - - -def _get_devices(no_gpu: bool) -> Union[int, str]: - """ - Get the number of GPUs/CPUs for the Trainer to use. - - Parameters - ---------- - no_gpu : bool - If true, disable all GPU usage. - - Returns - ------- - Union[int, str] - The number of GPUs/CPUs to use, or "auto" to let PyTorch Lightning - determine the appropriate number of devices. - """ - if not no_gpu and any( - operator.attrgetter(device + ".is_available")(torch)() - for device in ("cuda",) - ): - return -1 - elif not (n_workers := utils.n_workers()): - return "auto" - else: - return n_workers + found_files = set() + for path in paths: + path = os.path.expanduser(path) + path = os.path.expandvars(path) + for fname in glob.glob(path, recursive=True): + if Path(fname).suffix.lower() in supported_ext: + found_files.add(fname) + + return sorted(list(found_files)) diff --git a/casanovo/utils.py b/casanovo/utils.py index cca67747..b497ac12 100644 --- a/casanovo/utils.py +++ b/casanovo/utils.py @@ -1,4 +1,5 @@ """Small utility functions""" +import logging import os import platform import re @@ -8,6 +9,9 @@ import torch +logger = logging.getLogger("casanovo") + + def n_workers() -> int: """ Get the number of workers to use for data loading. @@ -26,6 +30,10 @@ def n_workers() -> int: """ # Windows or MacOS: no multiprocessing. if platform.system() in ["Windows", "Darwin"]: + logger.warning( + "Dataloader multiprocessing is currently not supported on Windows " + "or MacOS; using only a single thread." + ) return 0 # Linux: scale the number of workers by the number of GPUs (if present). try: diff --git a/docs/faq.md b/docs/faq.md index ec46bdc3..85d63cd3 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -60,8 +60,8 @@ Using the filename (column "filename") you can then retrieve the corresponding p By default, Casanovo saves a snapshot of the model weights after every 50,000 training steps. Note that the number of samples that are processed during a single training step depends on the batch size. -Therefore, when using the default training batch size of 32, this correspond to saving a model snapshot after every 1.6 million training samples. -You can optionally modify the snapshot frequency in the [config file](https://github.com/Noble-Lab/casanovo/blob/main/casanovo/config.yaml) (parameter `every_n_train_steps`), depending on your dataset size. +Therefore, when using the default training batch size of 32, this corresponds to saving a model snapshot after every 1.6 million training samples. 
+You can optionally modify the snapshot (and validation) frequency in the [config file](https://github.com/Noble-Lab/casanovo/blob/main/casanovo/config.yaml) (parameter `val_check_interval`), depending on your dataset size.
 
 Note that taking very frequent model snapshots will result in somewhat slower training time because Casanovo will evaluate its performance on the validation data for every snapshot.
 When saving a model snapshot, Casanovo will use the validation data to compute performance measures (training loss, validation loss, amino acid precision, and peptide precision) and print this information to the console and log file.
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 170cba66..94f0a308 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -16,8 +16,8 @@ Once you have conda installed, you can use this helpful [cheat sheet](https://do
 
 ### Create a conda environment
 
-Fist, open the terminal (MacOS and Linux) or the Anaconda Prompt (Windows).
-All of the commands that follow should be entered this terminal or Anaconda Prompt window---that is, your *shell*.
+First, open the terminal (MacOS and Linux) or the Anaconda Prompt (Windows).
+All of the commands that follow should be entered into this terminal or Anaconda Prompt window---that is, your *shell*.
 
 To create a new conda environment for Casanovo, run the following:
 
 ```sh
@@ -58,19 +58,29 @@ After installation, test that it was successful by viewing the Casanovo command
 
 ```sh
 casanovo --help
 ```
+![`casanovo --help`](images/help.svg)
 
-All auxiliary data, model, and training-related parameters can be specified in a user created `.yaml` configuration file.
-See [`casanovo/config.yaml`](https://github.com/Noble-Lab/casanovo/blob/main/casanovo/config.yaml) for the default configuration that was used to obtain the reported results. When running Casanovo in eval or denovo mode, you can change some of the parameters in this file, indicated with "(I)" in the file. You should not change other parameters unless you are training a new Casanovo model.
+
+All auxiliary data, model, and training-related parameters can be specified in a YAML configuration file.
+To generate a YAML file containing the current Casanovo defaults, run:
+```sh
+casanovo configure
+```
+![`casanovo configure --help`](images/configure-help.svg)
+
+When using Casanovo to sequence peptides from mass spectra or evaluate a previous model's performance, you can change some of the parameters in this file, indicated with "(I)" in the file.
+The other parameters will not have an effect unless you are training a new Casanovo model.
 
 ### Download model weights
 
-When running Casanovo in `denovo` or `eval` mode, Casanovo needs compatible pretrained model weights to make predictions.
-Our model weights are uploaded with new Casanovo versions on the [Releases page](https://github.com/Noble-Lab/casanovo/releases) under the "Assets" for each release (file extension: .ckpt).
-The model file can then be specified using the `--model` command-line parameter when executing Casanovo.
-To assist users, if no model file is specified Casanovo will try to download and use a compatible model file automatically.
+To sequence peptides from new mass spectra, Casanovo needs compatible pretrained model weights to make its predictions.
+By default, Casanovo will try to download the latest compatible model weights from GitHub when it is run.
-Not all releases might have a model file included on the [Releases page](https://github.com/Noble-Lab/casanovo/releases), in which case model weights for alternative releases with the same major version number can be used.
+However, our model weights are uploaded with new Casanovo versions on the [Releases page](https://github.com/Noble-Lab/casanovo/releases) under the "Assets" for each release (file extension: `.ckpt`).
+This model file or a custom one can then be specified using the `--model` command-line parameter when executing Casanovo.
+
+Not all releases will have a model file included on the [Releases page](https://github.com/Noble-Lab/casanovo/releases), in which case model weights for alternative releases with the same major version number can be used.
 The most recent model weights for Casanovo version 3.x are currently provided under [Casanovo v3.0.0](https://github.com/Noble-Lab/casanovo/releases/tag/v3.0.0):
 
 - `casanovo_massivekb.ckpt`: Default Casanovo weights to use when analyzing tryptic data. These weights will be downloaded automatically if no weights are explicitly specified.
@@ -80,45 +90,41 @@ The most recent model weights for Casanovo version 3.x are currently provided un
 
 ```{note}
 We recommend a Linux system with a dedicated GPU to achieve optimal runtime performance.
-Notably, Casanovo is restricted to single-threaded execution only on Windows and MacOS.
 ```
 
-> **Warning**
-> Casanovo can currently crash if no GPU is available.
-> We are actively trying to fix this known issue.
-
 ### Sequence new mass spectra
 
-To sequence your own mass spectra with Casanovo, use the `denovo` mode:
+To sequence your own mass spectra with Casanovo, use the `casanovo sequence` command:
 
 ```sh
-casanovo --mode=denovo --peak_path=path/to/predict/spectra.mgf --output=path/to/output
+casanovo sequence -o results.mztab spectra.mgf
 ```
+![`casanovo sequence --help`](images/sequence-help.svg)
 
 Casanovo can predict peptide sequences for MS/MS spectra in mzML, mzXML, and MGF files.
 This will write peptide predictions for the given MS/MS spectra to the specified output file in mzTab format.
 
-> **Warning**
-> If you are running inference with Casanovo on a system that has multiple GPUs, it is necessary to restrict Casanovo to (maximum) a single GPU.
-> For example, for CUDA-capable GPUs, GPU visibility can be controlled by setting the `CUDA_VISIBLE_DEVICES` shell variable.
-
 ### Evaluate *de novo* sequencing performance
 
-To evaluate _de novo_ sequencing performance based on known mass spectrum annotations, run:
+To evaluate _de novo_ sequencing performance based on known mass spectrum annotations, use the `casanovo evaluate` command:
 
 ```sh
-casanovo --mode=eval --peak_path=path/to/test/annotated_spectra.mgf
+casanovo evaluate annotated_spectra.mgf
 ```
+![`casanovo evaluate --help`](images/evaluate-help.svg)
+
-To evaluate the peptide predictions, ground truth peptide labels must to be provided as an annotated MGF file where the peptide sequence is denoted in the `SEQ` field.
+To evaluate the peptide predictions, ground truth peptide labels must be provided as an annotated MGF file where the peptide sequence is denoted in the `SEQ` field.
+Compatible MGF files are available from [MassIVE-KB](https://massive.ucsd.edu/ProteoSAFe/static/massive-kb-libraries.jsp).
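Because both evaluation and training rely on this annotated MGF convention, a minimal sketch of such an entry may help. All values below (title, precursor mass, peaks) are made up for illustration; only the `SEQ` field convention is taken from the docs above:

```python
# Sketch: write a tiny annotated MGF file of the kind expected by
# `casanovo evaluate` and `casanovo train`. The spectrum values are
# hypothetical; the peptide label lives in the SEQ field.
mgf_entry = """BEGIN IONS
TITLE=example_spectrum_1
PEPMASS=416.2447
CHARGE=2+
SEQ=LESLLEK
147.1128 1.0
260.1968 0.8
389.2394 0.5
END IONS
"""

with open("annotated_example.mgf", "w") as mgf_file:
    mgf_file.write(mgf_entry)
```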
 ### Train a new model
 
 To train a model from scratch, run:
 
 ```sh
-casanovo --mode=train --peak_path=path/to/train/annotated_spectra.mgf --peak_path_val=path/to/validation/annotated_spectra.mgf
+casanovo train --validation_peak_path validation_spectra.mgf training_spectra.mgf
 ```
+![`casanovo train --help`](images/train-help.svg)
 
 Training and validation MS/MS data need to be provided as annotated MGF files, where the peptide sequence is denoted in the `SEQ` field.
 
@@ -126,7 +132,7 @@ If a training is continued for a previously trained model, specify the starting
 
 ## Try Casanovo on a small example
 
-Here, we demonstrate how to use Casanovo using a small collection of mass spectra in an MGF file (~100 MS/MS spectra).
+Let's use Casanovo to sequence peptides from a small collection of mass spectra in an MGF file (~100 MS/MS spectra).
 The example MGF file is available at [`sample_data/sample_preprocessed_spectra.mgf`](https://github.com/Noble-Lab/casanovo/blob/main/sample_data/sample_preprocessed_spectra.mgf).
 
 To obtain *de novo* sequencing predictions for these spectra:
 
@@ -135,7 +141,7 @@ To obtain *de novo* sequencing predictions for these spectra:
 3. Ensure your Casanovo conda environment is activated by typing `conda activate casanovo_env`. (If you named your environment differently, type in that name instead.)
 4. Sequence the mass spectra with Casanovo, replacing `[PATH_TO]` with the path to the example MGF file that you downloaded:
 ```sh
-casanovo --mode=denovo --peak_path=[PATH_TO]/sample_preprocessed_spectra.mgf
+casanovo sequence [PATH_TO]/sample_preprocessed_spectra.mgf
 ```
 
 ```{note}
diff --git a/docs/images/configure-help.svg b/docs/images/configure-help.svg
new file mode 100644
index 00000000..d5dd7aa8
--- /dev/null
+++ b/docs/images/configure-help.svg
@@ -0,0 +1,108 @@
+[SVG terminal screenshot generated by rich-codex: `casanovo configure --help`.
+It documents the `casanovo configure` subcommand ("Generate a Casanovo
+configuration file to customize. The casanovo configuration file is in the
+YAML format.") with its --output/-o FILE and --help/-h options.]
diff --git a/docs/images/evaluate-help.svg b/docs/images/evaluate-help.svg
new file mode 100644
index 00000000..e220664b
--- /dev/null
+++ b/docs/images/evaluate-help.svg
@@ -0,0 +1,167 @@
+[SVG terminal screenshot: `casanovo evaluate --help`. Usage: `casanovo
+evaluate [OPTIONS] ANNOTATED_PEAK_PATH...` ("Evaluate de novo peptide
+sequencing performance. ANNOTATED_PEAK_PATH must be one or more annotated MGF
+files, such as those provided by MassIVE-KB."). It lists the required
+ANNOTATED_PEAK_PATH argument and the shared --model/-m, --output/-o,
+--config/-c, --verbosity/-v [debug|info|warning|error], and --help/-h
+options.]
diff --git a/docs/images/help.svg b/docs/images/help.svg
new file mode 100644
index 00000000..42180a3f
--- /dev/null
+++ b/docs/images/help.svg
@@ -0,0 +1,201 @@
+[SVG terminal screenshot: `casanovo --help`. Usage: `casanovo [OPTIONS]
+COMMAND [ARGS]...` ("Casanovo de novo sequences peptides from tandem mass
+spectra using a Transformer model. Casanovo currently supports mzML, mzXML,
+and MGF files for de novo sequencing and annotated MGF files, such as those
+from MassIVE-KB, for training new models."). It links to the documentation
+(https://casanovo.readthedocs.io) and the official code repository
+(https://github.com/Noble-Lab/casanovo), cites Yilmaz et al., Proceedings of
+the 39th International Conference on Machine Learning - ICML '22 (2022),
+doi:10.1101/2022.02.07.479481, and lists the subcommands: configure
+(generate a Casanovo configuration file to customize), evaluate (evaluate de
+novo peptide sequencing performance), sequence (de novo sequence peptides
+from tandem mass spectra), train (train a Casanovo model on your own data),
+and version (get the Casanovo version information).]
diff --git a/docs/images/sequence-help.svg b/docs/images/sequence-help.svg
new file mode 100644
index 00000000..d493e2b2
--- /dev/null
+++ b/docs/images/sequence-help.svg
@@ -0,0 +1,167 @@
+[SVG terminal screenshot: `casanovo sequence --help`. Usage: `casanovo
+sequence [OPTIONS] PEAK_PATH...` ("De novo sequence peptides from tandem mass
+spectra. PEAK_PATH must be one or more mzML, mzXML, or MGF files from which
+to sequence peptides."). It lists the required PEAK_PATH argument and the
+shared --model/-m, --output/-o, --config/-c, --verbosity/-v, and --help/-h
+options.]
diff --git a/docs/images/train-help.svg b/docs/images/train-help.svg
new file mode 100644
index 00000000..82c30122
--- /dev/null
+++ b/docs/images/train-help.svg
@@ -0,0 +1,219 @@
+[SVG terminal screenshot: `casanovo train --help`. Usage: `casanovo train
+[OPTIONS] TRAIN_PEAK_PATH...` ("Train a Casanovo model on your own data.
+TRAIN_PEAK_PATH must be one or more annotated MGF files, such as those
+provided by MassIVE-KB, from which to train a new Casanovo model."). It lists
+the required TRAIN_PEAK_PATH argument, the required --validation_peak_path/-p
+option (an annotated MGF file for validation, like from MassIVE-KB; use this
+option multiple times to specify multiple files), and the shared --model/-m,
+--output/-o, --config/-c, --verbosity/-v, and --help/-h options.]
diff --git a/pyproject.toml b/pyproject.toml
index db0fbe35..551954ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,20 +20,21 @@ classifiers = [
 requires-python = ">=3.8"
 dependencies = [
     "appdirs",
+    "lightning>=2.0",
     "click",
-    "depthcharge-ms>=0.1.0,<0.2.0",
+    "depthcharge-ms>=0.2.3,<0.3.0",
     "natsort",
     "numpy",
     "pandas",
     "psutil",
     "PyGithub",
-    "pytorch-lightning>=1.7,<2.0",
     "PyYAML",
     "requests",
+    "rich-click>=1.6.1",
     "scikit-learn",
     "spectrum_utils",
     "tensorboard",
-    "torch>=1.9",
+    "torch>=2.1",
     "tqdm",
 ]
 dynamic = ["version"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 574aed8f..a690bd8a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,7 @@
 import numpy as np
 import psims
 import pytest
+import yaml
 from pyteomics.mass import calculate_mass
 
 
@@ -180,3 +181,82 @@ def _create_mzml(peptides, mzml_file, random_state=42):
     )
 
     return mzml_file
+
+
+@pytest.fixture
+def tiny_config(tmp_path):
+    """A config file for a tiny model."""
+    cfg = {
+        "n_head": 2,
+        "dim_feedforward": 10,
+        "n_layers": 1,
+        "train_label_smoothing": 0.01,
+        "warmup_iters": 1,
+        "max_iters": 1,
+        "max_epochs": 20,
+        "val_check_interval": 1,
+        "model_save_folder_path": str(tmp_path),
+        "accelerator": "cpu",
+        "precursor_mass_tol": 5,
+        "isotope_error_range": [0, 1],
+        "min_peptide_len": 6,
+        "predict_batch_size": 1024,
+        "n_beams": 1,
+        "top_match": 1,
+        "devices": None,
+        "random_seed": 454,
+        "n_log": 1,
+        "tb_summarywriter": None,
+        "save_top_k": 5,
+        "n_peaks": 150,
+        "min_mz": 50.0,
+        "max_mz": 2500.0,
+        "min_intensity": 0.01,
+        "remove_precursor_tol": 2.0,
+        "max_charge": 10,
+        "dim_model": 512,
+        "dropout": 0.0,
+        "dim_intensity": None,
+        "max_length": 100,
+        "learning_rate": 5e-4,
+        "weight_decay": 1e-5,
+        "train_batch_size": 32,
+        "num_sanity_val_steps": 0,
+        "train_from_scratch": True,
+        "calculate_precision": False,
+        "residues": {
+            "G": 57.021464,
+            "A": 71.037114,
+            "S": 87.032028,
+            "P": 97.052764,
+            "V": 99.068414,
+            "T": 101.047670,
+            "C+57.021": 160.030649,
+            "L": 113.084064,
+            "I": 113.084064,
+            "N": 114.042927,
+            "D": 115.026943,
+            "Q": 128.058578,
+            "K": 128.094963,
+            "E": 129.042593,
+            "M": 131.040485,
+            "H": 137.058912,
+            "F": 147.068414,
+            "R": 156.101111,
+            "Y": 163.063329,
+            "W": 186.079313,
+            "M+15.995": 147.035400,
+            "N+0.984": 115.026943,
+            "Q+0.984": 129.042594,
+            "+42.011": 42.010565,
+            "+43.006": 43.005814,
+            "-17.027": -17.026549,
+            "+43.006-17.027": 25.980265,
+        },
+    }
+
+    cfg_file = tmp_path / "config.yml"
+    with cfg_file.open("w+") as out_file:
+        yaml.dump(cfg, out_file)
+
+    return cfg_file
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 7d63f5c1..e5d4b285 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,26 +1,71 @@
+import functools
+from pathlib import Path
+
 import pyteomics.mztab
+from click.testing import CliRunner
 
 from casanovo import casanovo
 
 
-def test_denovo(mgf_small, mzml_small, tmp_path, monkeypatch):
+def test_train_and_run(
+    mgf_small, mzml_small, tiny_config, tmp_path, monkeypatch
+):
     # We can use this to explicitly test different versions.
     monkeypatch.setattr(casanovo, "__version__", "3.0.1")
 
-    # Predict on small files (MGF and mzML) and verify that the output mzTab
-    # file exists.
-    output_filename = tmp_path / "test.mztab"
-    casanovo.main(
-        [
-            "--mode",
-            "denovo",
-            "--peak_path",
-            str(mgf_small).replace(".mgf", ".m*"),
-            "--output",
-            str(output_filename),
-        ],
-        standalone_mode=False,
+    # Helper to run CLI commands:
+    run = functools.partial(
+        CliRunner().invoke, casanovo.main, catch_exceptions=False
     )
+
+    # Train a tiny model:
+    train_args = [
+        "train",
+        "--validation_peak_path",
+        str(mgf_small),
+        "--config",
+        str(tiny_config),
+        "--output",
+        str(tmp_path / "train"),
+        str(mgf_small),  # The training files.
+    ]
+
+    result = run(train_args)
+    model_file = tmp_path / "epoch=19-step=20.ckpt"
+    assert result.exit_code == 0
+    assert model_file.exists()
+
+    # Try evaluating:
+    eval_args = [
+        "evaluate",
+        "--model",
+        str(model_file),
+        "--config",
+        str(tiny_config),
+        "--output",
+        str(tmp_path / "eval"),
+        str(mgf_small),
+    ]
+
+    result = run(eval_args)
+    assert result.exit_code == 0
+
+    # Finally try predicting:
+    output_filename = tmp_path / "test.mztab"
+    predict_args = [
+        "sequence",
+        "--model",
+        str(model_file),
+        "--config",
+        str(tiny_config),
+        "--output",
+        str(output_filename),
+        str(mgf_small),
+        str(mzml_small),
+    ]
+
+    result = run(predict_args)
+    assert result.exit_code == 0
     assert output_filename.is_file()
 
     mztab = pyteomics.mztab.MzTab(str(output_filename))
@@ -29,8 +74,8 @@
         assert f"ms_run[{i}]-location" in mztab.metadata
         assert mztab.metadata[f"ms_run[{i}]-location"].endswith(filename)
 
-    # Verify that the spectrum predictions are correct and indexed according to
-    # the peak input file type.
+    # Verify that the spectrum predictions are correct
+    # and indexed according to the peak input file type.
     psms = mztab.spectrum_match_table
     assert psms.loc[1, "sequence"] == "LESLLEK"
     assert psms.loc[1, "spectra_ref"] == "ms_run[1]:index=0"
@@ -40,3 +85,20 @@
     assert psms.loc[3, "spectra_ref"] == "ms_run[2]:scan=17"
     assert psms.loc[4, "sequence"] == "PEPTLDEK"
     assert psms.loc[4, "spectra_ref"] == "ms_run[2]:scan=111"
+
+
+def test_auxiliary_cli(tmp_path, monkeypatch):
+    """Test the secondary CLI commands."""
+    run = functools.partial(
+        CliRunner().invoke, casanovo.main, catch_exceptions=False
+    )
+
+    monkeypatch.chdir(tmp_path)
+    run("configure")
+    assert Path("casanovo.yaml").exists()
+
+    run(["configure", "-o", "test.yaml"])
+    assert Path("test.yaml").exists()
+
+    res = run("version")
+    assert res.output
diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py
index 89b37e42..7a0d7a26 100644
--- a/tests/unit_tests/test_config.py
+++ b/tests/unit_tests/test_config.py
@@ -1,4 +1,7 @@
 """Test configuration loading"""
+import pytest
+import yaml
+
 from casanovo.config import Config
 
 
@@ -7,31 +10,29 @@ def test_default():
     config = Config()
     assert config.random_seed == 454
     assert config["random_seed"] == 454
-    assert not config.no_gpu
+    assert config.accelerator == "auto"
     assert config.file == "default"
 
 
-def test_override(tmp_path):
-    """Test overriding the default"""
-    yml = tmp_path / "test.yml"
-    with yml.open("w+") as f_out:
-        f_out.write(
-            """random_seed: 42
-top_match: 3
-residues:
-  W: 1
-  O: 2
-  U: 3
-  T: 4
-"""
-        )
-
-    config = Config(yml)
-    assert config.random_seed == 42
-    assert config["random_seed"] == 42
-    assert not config.no_gpu
-    assert config.top_match == 3
-    assert len(config.residues) == 4
-    for i, residue in enumerate("WOUT", 1):
-        assert config["residues"][residue] == i
-    assert config.file == str(yml)
+def test_override(tmp_path, tiny_config):
+    # Test that an expected config option is missing.
+    filename = str(tmp_path / "config_missing.yml")
+    with open(tiny_config, "r") as f_in, open(filename, "w") as f_out:
+        cfg = yaml.safe_load(f_in)
+        # Remove a config option.
+        del cfg["random_seed"]
+        yaml.safe_dump(cfg, f_out)
+
+    with pytest.raises(KeyError):
+        Config(filename)
+
+    # Test that an invalid config option is present.
+    filename = str(tmp_path / "config_invalid.yml")
+    with open(tiny_config, "r") as f_in, open(filename, "w") as f_out:
+        cfg = yaml.safe_load(f_in)
+        # Insert an invalid config option.
+        cfg["random_seed_"] = 354
+        yaml.safe_dump(cfg, f_out)
+
+    with pytest.raises(KeyError):
+        Config(filename)
diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py
new file mode 100644
index 00000000..6be91831
--- /dev/null
+++ b/tests/unit_tests/test_runner.py
@@ -0,0 +1,118 @@
+"""Unit tests specifically for the model_runner module."""
+import pytest
+import torch
+
+from casanovo.config import Config
+from casanovo.denovo.model_runner import ModelRunner
+
+
+def test_initialize_model(tmp_path):
+    """Test initializing a new or existing model."""
+    config = Config()
+    config.train_from_scratch = False
+    ModelRunner(config=config).initialize_model(train=True)
+
+    with pytest.raises(ValueError):
+        ModelRunner(config=config).initialize_model(train=False)
+
+    with pytest.raises(FileNotFoundError):
+        runner = ModelRunner(config=config, model_filename="blah")
+        runner.initialize_model(train=True)
+
+    with pytest.raises(FileNotFoundError):
+        runner = ModelRunner(config=config, model_filename="blah")
+        runner.initialize_model(train=False)
+
+    # This should work now:
+    config.train_from_scratch = True
+    runner = ModelRunner(config=config, model_filename="blah")
+    runner.initialize_model(train=True)
+
+    # But this should still fail:
+    with pytest.raises(FileNotFoundError):
+        runner = ModelRunner(config=config, model_filename="blah")
+        runner.initialize_model(train=False)
+
+    # If the model initialization throws an EOFError, then the Spec2Pep model
+    # has tried to load the weights:
+    weights = tmp_path / "blah"
+    weights.touch()
+    with pytest.raises(EOFError):
+        runner = ModelRunner(config=config, model_filename=str(weights))
+        runner.initialize_model(train=False)
+
+
+def test_save_and_load_weights(tmp_path, mgf_small, tiny_config):
+    """Test saving and loading weights."""
+    config = Config(tiny_config)
+    config.max_epochs = 1
+    config.n_layers = 1
+    ckpt = tmp_path / "test.ckpt"
+
+    with ModelRunner(config=config) as runner:
+        runner.train([mgf_small], [mgf_small])
+        runner.trainer.save_checkpoint(ckpt)
+
+    # Try changing the model architecture:
+    other_config = Config(tiny_config)
+    other_config.n_layers = 50  # A mismatch with the checkpoint.
+    other_config.n_beams = 12
+    other_config.max_iters = 2
+    with torch.device("meta"):
+        # Now load the weights into a new model.
+        # The device should be meta for all the weights.
+        runner = ModelRunner(config=other_config, model_filename=str(ckpt))
+        runner.initialize_model(train=False)
+
+    obs_layers = runner.model.encoder.transformer_encoder.num_layers
+    assert obs_layers == 1  # Matches the original architecture.
+    assert runner.model.n_beams == 12  # Matches the config.
+    assert runner.model.max_iters == 2  # Matches the config.
+    assert next(runner.model.parameters()).device == torch.device("meta")
+
+    # If the Trainer correctly moves the weights to the accelerator,
+    # then it should fail if the weights are on the "meta" device.
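+    # Tensors created under `torch.device("meta")` carry shape and dtype
+    # information but no data, so copying them to a real device raises
+    # NotImplementedError; the assertion below relies on this behavior.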
+ with torch.device("meta"): + with ModelRunner(other_config, model_filename=str(ckpt)) as runner: + with pytest.raises(NotImplementedError) as err: + runner.evaluate([mgf_small]) + + assert "meta tensor; no data!" in str(err.value) + + # Try without arch: + ckpt_data = torch.load(ckpt) + del ckpt_data["hyper_parameters"] + torch.save(ckpt_data, ckpt) + + # Shouldn't work: + with ModelRunner(other_config, model_filename=str(ckpt)) as runner: + with pytest.raises(RuntimeError): + runner.evaluate([mgf_small]) + + # Should work: + with ModelRunner(config=config, model_filename=str(ckpt)) as runner: + runner.evaluate([mgf_small]) + + +def test_calculate_precision(tmp_path, mgf_small, tiny_config): + """Test that this parameter is working correctly.""" + config = Config(tiny_config) + config.n_layers = 1 + config.max_epochs = 1 + config.calculate_precision = False + config.tb_summarywriter = str(tmp_path) + + runner = ModelRunner(config=config) + with runner: + runner.train([mgf_small], [mgf_small]) + + assert "valid_aa_precision" not in runner.model.history.columns + assert "valid_pep_precision" not in runner.model.history.columns + + config.calculate_precision = True + runner = ModelRunner(config=config) + with runner: + runner.train([mgf_small], [mgf_small]) + + assert "valid_aa_precision" in runner.model.history.columns + assert "valid_pep_precision" in runner.model.history.columns diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index bc0509bd..6b840d20 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -69,6 +69,7 @@ def test_split_version(): assert version == ("3", "0", "1") +@pytest.mark.skip(reason="Hit rate limit during CI/CD") def test_get_model_weights(monkeypatch): """ Test that model weights can be downloaded from GitHub or used from the @@ -341,6 +342,34 @@ def test_beam_search_decode(): assert torch.equal(pred_cache[0][0][-1], torch.tensor([4, 14, 4, 13])) + # Test _get_topk_beams(). + step = 1 + scores = torch.full( + size=(batch, length, vocab, beam), fill_value=torch.nan + ) + scores = einops.rearrange(scores, "B L V S -> (B S) L V") + tokens = torch.zeros(batch * beam, length, dtype=torch.int64) + tokens[0, 0] = 4 + scores[0, step, :] = 0 + scores[0, step, 14] = torch.tensor([1]) + test_finished_beams = torch.tensor([False]) + + new_tokens, new_scores = model._get_topk_beams( + tokens, scores, test_finished_beams, batch, step + ) + + expected_tokens = torch.tensor( + [ + [4, 14], + ] + ) + + expected_scores = torch.zeros(beam, vocab) + expected_scores[:, 14] = torch.tensor([1]) + + assert torch.equal(new_scores[:, step, :], expected_scores) + assert torch.equal(new_tokens[:, : step + 1], expected_tokens) + # Test _finish_beams() for tokens with a negative mass. 
     model = Spec2Pep(n_beams=2, residues="massivekb")
     beam = model.n_beams  # S
@@ -485,15 +514,28 @@ def test_spectrum_id_mzml(mzml_small, tmp_path):
 
 def test_train_val_step_functions():
     """Test train and validation step functions operating on batches."""
-    model = Spec2Pep(n_beams=1, residues="massivekb", min_peptide_len=4)
+    model = Spec2Pep(
+        n_beams=1,
+        residues="massivekb",
+        min_peptide_len=4,
+        train_label_smoothing=0.1,
+    )
     spectra = torch.zeros(1, 5, 2)
     precursors = torch.tensor([[469.25364, 2.0, 235.63410]])
     peptides = ["PEPK"]
     batch = (spectra, precursors, peptides)
 
+    train_step_loss = model.training_step(batch)
+    val_step_loss = model.validation_step(batch)
+
     # Check if valid loss value returned
-    assert model.training_step(batch) > 0
-    assert model.validation_step(batch) > 0
+    assert train_step_loss > 0
+    assert val_step_loss > 0
+
+    # Check that label smoothing is applied during training but not during
+    # validation.
+    assert model.celoss.label_smoothing == 0.1
+    assert model.val_celoss.label_smoothing == 0
+    assert val_step_loss != train_step_loss
 
 
 def test_run_map(mgf_small):
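The asymmetry this test pins down, a smoothed training loss versus an unsmoothed validation loss, can be reproduced with stock PyTorch alone. A minimal sketch (the logits are arbitrary; this is not Casanovo code):

```python
import torch
from torch import nn

# A confident, correct prediction: the logit of the true class dominates.
logits = torch.tensor([[4.0, 1.0, -1.0]])
target = torch.tensor([0])

train_loss = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, target)
val_loss = nn.CrossEntropyLoss(label_smoothing=0.0)(logits, target)

# Label smoothing reserves probability mass for the non-target classes, so
# the smoothed (training-style) loss exceeds the unsmoothed
# (validation-style) loss for confident correct predictions.
assert train_loss > val_loss
```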