diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py
index 72b02057..0be7a784 100644
--- a/casanovo/casanovo.py
+++ b/casanovo/casanovo.py
@@ -15,10 +15,8 @@
 import click
 import github
 import requests
-import torch
 import tqdm
-import yaml
-from pytorch_lightning.lite import LightningLite
+from lightning import Fabric
 
 from . import __version__
 from . import utils
@@ -127,7 +125,7 @@ def main(
 
     # Read parameters from the config file.
     config = Config(config)
-    LightningLite.seed_everything(seed=config["random_seed"], workers=True)
+    Fabric.seed_everything(seed=config["random_seed"], workers=True)
 
     # Download model weights if these were not specified (except when training).
     if model is None and mode != "train":
diff --git a/casanovo/config.py b/casanovo/config.py
index 4dc93c26..352d6743 100644
--- a/casanovo/config.py
+++ b/casanovo/config.py
@@ -50,6 +50,7 @@ class Config:
         dropout=float,
         dim_intensity=int,
         max_length=int,
+        residues=dict,  # note, this key is special-cased and type is ignored
         n_log=int,
         tb_summarywriter=str,
         warmup_iters=int,
diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py
index 8b3ae3e0..32a50b31 100644
--- a/casanovo/denovo/model.py
+++ b/casanovo/denovo/model.py
@@ -724,8 +724,8 @@ def training_step(
         pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1)
         loss = self.celoss(pred, truth.flatten())
         self.log(
-            "CELoss",
-            {mode: loss.detach()},
+            f"CELoss/{mode}",
+            loss.detach(),
             on_step=False,
             on_epoch=True,
             sync_dist=True,
@@ -765,13 +765,11 @@ def validation_step(
         )
         log_args = dict(on_step=False, on_epoch=True, sync_dist=True)
         self.log(
-            "Peptide precision at coverage=1",
-            {"valid": pep_precision},
+            "Peptide precision at coverage=1/valid",
+            pep_precision,
             **log_args,
         )
-        self.log(
-            "AA precision at coverage=1", {"valid": aa_precision}, **log_args
-        )
+        self.log("AA precision at coverage=1/valid", aa_precision, **log_args)
 
         return loss
 
@@ -824,7 +822,7 @@ def on_train_epoch_end(self) -> None:
         """
         Log the training loss at the end of each epoch.
         """
-        train_loss = self.trainer.callback_metrics["CELoss"]["train"].detach()
+        train_loss = self.trainer.callback_metrics["CELoss/train"].detach()
         metrics = {
             "step": self.trainer.global_step,
             "train": train_loss,
@@ -839,20 +837,18 @@ def on_validation_epoch_end(self) -> None:
         callback_metrics = self.trainer.callback_metrics
         metrics = {
             "step": self.trainer.global_step,
-            "valid": callback_metrics["CELoss"]["valid"].detach(),
+            "valid": callback_metrics["CELoss/valid"].detach(),
             "valid_aa_precision": callback_metrics[
-                "AA precision at coverage=1"
-            ]["valid"].detach(),
+                "AA precision at coverage=1/valid"
+            ].detach(),
             "valid_pep_precision": callback_metrics[
-                "Peptide precision at coverage=1"
-            ]["valid"].detach(),
+                "Peptide precision at coverage=1/valid"
+            ].detach(),
         }
         self._history.append(metrics)
         self._log_history()
 
-    def on_predict_epoch_end(
-        self, results: List[List[Tuple[np.ndarray, List[str], torch.Tensor]]]
-    ) -> None:
+    def on_predict_epoch_end(self) -> None:
         """
         Write the predicted peptide sequences and amino acid scores to the
         output file.
@@ -868,7 +864,7 @@ def on_predict_epoch_end(
             peptide_score,
             aa_scores,
         ) in itertools.chain.from_iterable(
-            itertools.chain.from_iterable(results)
+            self.trainer.predict_loop.predictions
         ):
             if len(peptide) == 0:
                 continue
diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index fb5deeba..add48c93 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -14,6 +14,8 @@
 from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex
 from pytorch_lightning.strategies import DDPStrategy
 
+from lightning.pytorch.accelerators import find_usable_cuda_devices
+
 from .. import utils
 from ..data import ms_io
 from ..denovo.dataloaders import DeNovoDataModule
@@ -96,8 +98,12 @@ def _execute_existing(
             model_filename,
         )
         raise FileNotFoundError("Could not find the trained model weights")
+    map_location = None
+    if torch.cuda.device_count() == 0:
+        map_location = "cpu"
     model = Spec2Pep().load_from_checkpoint(
         model_filename,
+        map_location=map_location,
         dim_model=config["dim_model"],
         n_head=config["n_head"],
         dim_feedforward=config["dim_feedforward"],
@@ -158,7 +164,6 @@ def _execute_existing(
     # Create the Trainer object.
     trainer = pl.Trainer(
         accelerator="auto",
-        auto_select_gpus=True,
         devices=_get_devices(config["no_gpu"]),
         logger=config["logger"],
         max_epochs=config["max_epochs"],
@@ -304,7 +309,6 @@ def train(
 
     trainer = pl.Trainer(
         accelerator="auto",
-        auto_select_gpus=True,
         callbacks=callbacks,
         devices=_get_devices(config["no_gpu"]),
         enable_checkpointing=config["save_model"],
@@ -352,7 +356,7 @@ def _get_peak_filenames(
     ]
 
 
-def _get_strategy() -> Optional[DDPStrategy]:
+def _get_strategy() -> Union[DDPStrategy, str]:
     """
     Get the strategy for the Trainer.
 
@@ -362,16 +366,16 @@ def _get_strategy() -> Optional[DDPStrategy]:
 
     Returns
     -------
-    Optional[DDPStrategy]
+    Union[DDPStrategy, str]
         The strategy parameter for the Trainer.
     """
     if torch.cuda.device_count() > 1:
         return DDPStrategy(find_unused_parameters=False, static_graph=True)
-    return None
+    return "auto"
 
 
-def _get_devices(no_gpu: bool) -> Union[int, str]:
+def _get_devices(no_gpu: bool) -> Union[List[int], str]:
     """
     Get the number of GPUs/CPUs for the Trainer to use.
 
@@ -382,16 +386,14 @@ def _get_devices(no_gpu: bool) -> Union[int, str]:
 
     Returns
     -------
-    Union[int, str]
-        The number of GPUs/CPUs to use, or "auto" to let PyTorch Lightning
-        determine the appropriate number of devices.
+    Union[List[int], str]
+        A list of CUDA GPU devices to use, or "auto" to let PyTorch Lightning
+        determine the appropriate number of devices.
     """
     if not no_gpu and any(
         operator.attrgetter(device + ".is_available")(torch)()
         for device in ("cuda",)
     ):
-        return -1
-    elif not (n_workers := utils.n_workers()):
-        return "auto"
+        return find_usable_cuda_devices()
     else:
-        return n_workers
+        return "auto"
diff --git a/pyproject.toml b/pyproject.toml
index db0fbe35..0ba22c32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
     "pandas",
     "psutil",
     "PyGithub",
-    "pytorch-lightning>=1.7,<2.0",
+    "lightning>=2.0.0",
     "PyYAML",
     "requests",
     "scikit-learn",
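
A note on the metric-logging changes in casanovo/denovo/model.py: Lightning 2.x no longer accepts a dictionary of values in LightningModule.log, so each scalar now gets its own hierarchical key ("CELoss/train" rather than "CELoss" mapping to {"train": ...}), and that flattened key is also how the value is read back from trainer.callback_metrics. A minimal sketch of the pattern, assuming lightning>=2.0 is installed; the Demo module, its toy linear layer, and the MSE loss are illustrative stand-ins, not Casanovo code:

import torch
import lightning.pytorch as pl


class Demo(pl.LightningModule):
    """Toy module illustrating flattened metric keys (sketch only)."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # PL 1.x allowed self.log("CELoss", {"train": loss}); Lightning 2.x
        # does not, so log one scalar per hierarchical key instead.
        self.log("CELoss/train", loss.detach(), on_step=False, on_epoch=True)
        return loss

    def on_train_epoch_end(self):
        # The same flattened key indexes the logged value.
        print(self.trainer.callback_metrics["CELoss/train"])

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)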
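
Likewise for the device-selection changes in casanovo/denovo/model_runner.py: Lightning 2.x removed the auto_select_gpus Trainer flag, and find_usable_cuda_devices() is its closest replacement. Below is a hedged sketch of that selection logic under these assumptions; get_devices and get_strategy are illustrative names, and the sketch imports everything from the lightning.pytorch namespace, whereas the patch itself still takes DDPStrategy from pytorch_lightning.strategies:

import torch
import lightning.pytorch as pl
from lightning.pytorch.accelerators import find_usable_cuda_devices
from lightning.pytorch.strategies import DDPStrategy


def get_devices(no_gpu: bool):
    # With auto_select_gpus gone, ask for the usable CUDA devices explicitly,
    # or defer to "auto" when running on CPU.
    if not no_gpu and torch.cuda.is_available():
        return find_usable_cuda_devices()
    return "auto"


def get_strategy():
    # DDP only pays off across multiple GPUs; otherwise let Lightning choose.
    if torch.cuda.device_count() > 1:
        return DDPStrategy(find_unused_parameters=False, static_graph=True)
    return "auto"


trainer = pl.Trainer(
    accelerator="auto",
    devices=get_devices(no_gpu=False),
    strategy=get_strategy(),
)

The map_location fallback added in _execute_existing serves the same end at load time: without it, a checkpoint saved on a GPU machine fails to deserialize on a CPU-only one.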