
Commit

Fix type hints and line lengths
bittremieux committed Nov 10, 2024
1 parent 1efd9dd commit f679cdc
Showing 2 changed files with 82 additions and 74 deletions.
142 changes: 75 additions & 67 deletions casanovo/casanovo.py
@@ -12,7 +12,7 @@
import urllib.parse
import warnings
from pathlib import Path
from typing import Optional, Tuple, List
from typing import Optional, Tuple

warnings.formatwarning = lambda message, category, *args, **kwargs: (
f"{category.__name__}: {message}"
@@ -62,19 +62,19 @@ def __init__(self, *args, **kwargs) -> None:
click.Option(
("-m", "--model"),
help="""
Either the model weights (.ckpt file) or a URL pointing to
the model weights file. If not provided,
Casanovo will try to download the latest release automatically.
Either the model weights (.ckpt file) or a URL pointing to the
model weights file. If not provided, Casanovo will try to
download the latest release automatically.
""",
),
click.Option(
("-d", "--output_dir"),
help="The destination directory for output files",
help="The destination directory for output files.",
type=click.Path(dir_okay=True),
),
click.Option(
("-o", "--output_root"),
help="The root name for all output files",
help="The root name for all output files.",
type=click.Path(dir_okay=False),
),
click.Option(
@@ -113,9 +113,9 @@ def main() -> None:
========
Casanovo de novo sequences peptides from tandem mass spectra using a
Transformer model. Casanovo currently supports mzML, mzXML, and MGF files
for de novo sequencing and annotated MGF files, such as those from
MassIVE-KB, for training new models.
Transformer model. Casanovo currently supports mzML, mzXML, and MGF
files for de novo sequencing and annotated MGF files, such as those
from MassIVE-KB, for training new models.
Links:
@@ -124,10 +124,10 @@ def main() -> None:
If you use Casanovo in your work, please cite:
- Yilmaz, M., Fondrie, W. E., Bittremieux, W., Oh, S. & Noble, W. S. De novo
mass spectrometry peptide sequencing with a transformer model. Proceedings
of the 39th International Conference on Machine Learning - ICML '22 (2022)
doi:10.1101/2022.02.07.479481.
- Yilmaz, M., Fondrie, W. E., Bittremieux, W., Oh, S. & Noble, W. S.
De novo mass spectrometry peptide sequencing with a transformer
model. Proceedings of the 39th International Conference on Machine
Learning - ICML '22 (2022) doi:10.1101/2022.02.07.479481.
"""
@@ -147,9 +147,9 @@ def main() -> None:
is_flag=True,
default=False,
help="""
Run in evaluation mode. When this flag is set the peptide and amino
acid precision will be calculated and logged at the end of the sequencing
run. All input files must be annotated MGF files if running in evaluation
Run in evaluation mode. When this flag is set the peptide and amino acid
precision will be calculated and logged at the end of the sequencing run.
All input files must be annotated MGF files if running in evaluation
mode.
""",
)
@@ -290,8 +290,9 @@ def train(
) -> None:
"""Train a Casanovo model on your own data.
TRAIN_PEAK_PATH must be one or more annotated MGF files, such as those
provided by MassIVE-KB, from which to train a new Casanovo model.
TRAIN_PEAK_PATH must be one or more annotated MGF files, such as
those provided by MassIVE-KB, from which to train a new Casanovo
model.
"""
output_path, output_root_name = _setup_output(
output_dir, output_root, force_overwrite, verbosity
@@ -324,7 +325,7 @@ def train(

@main.command()
def version() -> None:
"""Get the Casanovo version information"""
"""Get the Casanovo version information."""
versions = [
f"Casanovo: {__version__}",
f"Depthcharge: {depthcharge.__version__}",
@@ -342,20 +343,20 @@ def version() -> None:
default="casanovo.yaml",
type=click.Path(dir_okay=False),
)
def configure(output: str) -> None:
def configure(output: Path) -> None:
"""Generate a Casanovo configuration file to customize.
The casanovo configuration file is in the YAML format.
"""
Config.copy_default(output)
output = setup_logging(output, "info")
Config.copy_default(str(output))
setup_logging(output, "info")
logger.info(f"Wrote {output}\n")


def setup_logging(
log_file_path: Path,
verbosity: str,
) -> Path:
) -> None:
"""Set up the logger.
Logging occurs to the command-line and to the given log file.
Expand Down Expand Up @@ -423,48 +424,50 @@ def setup_model(
Parameters
----------
model : str | None
May be a file system path, a URL pointing to a .ckpt file, or None.
If `model` is a URL the weights will be downloaded and cached from
`model`. If `model` is `None` the weights from the latest matching
official release will be used (downloaded and cached).
May be a file system path, a URL pointing to a .ckpt file, or
None. If `model` is a URL the weights will be downloaded and
cached from `model`. If `model` is `None` the weights from the
latest matching official release will be used (downloaded and
cached).
config : str | None
Config file path. If None the default config will be used.
output_dir : Path | str
The path to the output directory.
output_root_name : str
The base name for the output files.
is_train : bool
Are we training? If not, we need to retrieve weights when the model is
None.
Are we training? If not, we need to retrieve weights when the
model is None.
Return
------
Tuple[Config, Path]
Initialized Casanovo config, local path to model weights if any (may be
`None` if training using random starting weights).
Initialized Casanovo config, local path to model weights if any
(may be `None` if training using random starting weights).
"""
# Read parameters from the config file.
config = Config(config)
seed_everything(seed=config["random_seed"], workers=True)

# Download model weights if these were not specified (except when training).
# Download model weights if these were not specified (except when
# training).
cache_dir = Path(appdirs.user_cache_dir("casanovo", False, opinion=False))
if model is None:
if not is_train:
try:
model = _get_model_weights(cache_dir)
except github.RateLimitExceededException:
logger.error(
"GitHub API rate limit exceeded while trying to download the "
"model weights. Please download compatible model weights "
"manually from the official Casanovo code website "
"(https://github.com/Noble-Lab/casanovo) and specify these "
"explicitly using the `--model` parameter when running "
"Casanovo."
"GitHub API rate limit exceeded while trying to download "
"the model weights. Please download compatible model "
"weights manually from the official Casanovo code website "
"(https://github.com/Noble-Lab/casanovo) and specify "
"these explicitly using the `--model` parameter when "
"running Casanovo."
)
raise PermissionError(
"GitHub API rate limit exceeded while trying to download the "
"model weights"
"GitHub API rate limit exceeded while trying to download "
"the model weights"
) from None
else:
if _is_valid_url(model):
@@ -489,29 +492,30 @@
return config, model
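To make the parameters documented above concrete, here is a minimal, hypothetical usage sketch of setup_model; the keyword names are taken from the docstring in this hunk, and the values are examples rather than the project's actual defaults.

from pathlib import Path

# Hypothetical call for a sequencing (non-training) run: use the default
# config and let Casanovo fetch the latest matching release weights.
config, ckpt_path = setup_model(
    model=None,                  # no local .ckpt file or URL supplied
    config=None,                 # fall back to the default configuration
    output_dir=Path("results"),  # example output directory
    output_root_name="run1",     # example base name for output files
    is_train=False,              # not training, so weights are retrieved
)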


def _get_model_weights(cache_dir: Path) -> str:
def _get_model_weights(cache_dir: Path) -> Path:
"""
Use cached model weights or download them from GitHub.
If no weights file (extension: .ckpt) is available in the cache directory,
it will be downloaded from a release asset on GitHub.
Model weights are retrieved by matching release version. If no model weights
exist for an identical release (major, minor, patch), alternative releases with
matching (i) major and minor, or (ii) major versions will be used.
If no matching release can be found, no model weights will be downloaded.
If no weights file (extension: .ckpt) is available in the cache
directory, it will be downloaded from a release asset on GitHub.
Model weights are retrieved by matching release version. If no model
weights exist for an identical release (major, minor, patch), alternative
releases with matching (i) major and minor, or (ii) major versions
will be used. If no matching release can be found, no model weights
will be downloaded.
Note that the GitHub API is limited to 60 requests from the same IP per
hour.
Note that the GitHub API is limited to 60 requests from the same IP
per hour.
Parameters
----------
cache_dir : Path
model weights cache directory path
Model weights cache directory path.
Returns
-------
str
The name of the model weights file.
Path
The path of the model weights file.
"""
os.makedirs(cache_dir, exist_ok=True)
version = utils.split_version(__version__)
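As an illustration of the fallback order described in this docstring, a simplified matching helper might look like the following; this is an assumption-level sketch, not the actual logic inside _get_model_weights.

from typing import List, Optional, Tuple

def match_release(
    current: Tuple[int, int, int],
    available: List[Tuple[int, int, int]],
) -> Optional[Tuple[int, int, int]]:
    """Prefer an exact match, then matching major.minor, then major only."""
    for depth in (3, 2, 1):
        for candidate in available:
            if candidate[:depth] == current[:depth]:
                return candidate
    return None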
@@ -598,11 +602,11 @@ def _setup_output(
Parameters:
-----------
output_dir : str | None
The path to the output directory. If `None`, the output directory will
be resolved to the current working directory.
The path to the output directory. If `None`, the output
directory will be resolved to the current working directory.
output_root : str | None
The base name for the output files. If `None` the output root name will
be resolved to casanovo_<current date and time>
The base name for the output files. If `None` the output root
name will be resolved to casanovo_<current date and time>
overwrite: bool
Whether to overwrite log file if it already exists in the output
directory.
Expand All @@ -612,8 +616,8 @@ def _setup_output(
Returns:
--------
Tuple[Path, str]
A tuple containing the resolved output directory and root name for
output files.
A tuple containing the resolved output directory and root name
for output files.
"""
if output_root is None:
output_root = (
@@ -627,7 +631,8 @@
if not output_path.is_dir():
output_path.mkdir(parents=True)
logger.warning(
"Target output directory %s does not exists, so it will be created.",
"Target output directory %s does not exists, so it will be "
"created.",
output_path,
)
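For reference, the casanovo_<current date and time> fallback described above can be sketched as follows; the exact timestamp format is an assumption.

from datetime import datetime
from typing import Optional

def resolve_output_root(output_root: Optional[str]) -> str:
    """Fall back to a timestamped default when no root name is given."""
    if output_root is not None:
        return output_root
    # Example timestamp format only; the real pattern may differ.
    return f"casanovo_{datetime.now().strftime('%Y%m%d_%H%M%S')}"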

@@ -647,8 +652,8 @@ def _get_weights_from_url(
Resolve weight file from URL
Attempt to download weight file from URL if weights are not already
cached - otherwise use cached weights. Downloaded weight files will be
cached.
cached - otherwise use cached weights. Downloaded weight files will
be cached.
Parameters
----------
Expand All @@ -657,8 +662,8 @@ def _get_weights_from_url(
cache_dir : Path
Model weights cache directory path.
force_download : Optional[bool], default=False
If True, forces a new download of the weight file even if it exists in
the cache.
If True, forces a new download of the weight file even if it
exists in the cache.
Returns
-------
@@ -688,7 +693,8 @@ def _get_weights_from_url(
).timestamp()
else:
logger.warning(
"Attempted HEAD request to %s yielded non-ok status code - using cached file",
"Attempted HEAD request to %s yielded non-ok status code—"
"using cached file",
file_url,
)
except (
Expand All @@ -697,7 +703,8 @@ def _get_weights_from_url(
requests.TooManyRedirects,
):
logger.warning(
"Failed to reach %s to get remote last modified time - using cached file",
"Failed to reach %s to get remote last modified time—using "
"cached file",
file_url,
)
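The warnings above guard a cache-refresh check; a simplified sketch of that idea (assumed behavior, not the exact implementation of _get_weights_from_url) is shown below.

from datetime import datetime

import requests

def is_cache_stale(file_url: str, cached_mtime: float) -> bool:
    """Return True if the remote weights file appears newer than the cache."""
    try:
        response = requests.head(file_url, timeout=10)
    except (
        requests.ConnectionError,
        requests.Timeout,
        requests.TooManyRedirects,
    ):
        return False  # remote unreachable: keep using the cached file
    if not response.ok:
        return False  # non-ok status code: keep using the cached file
    remote_mtime = datetime.strptime(
        response.headers["Last-Modified"], "%a, %d %b %Y %H:%M:%S %Z"
    ).timestamp()
    return remote_mtime > cached_mtime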

@@ -715,8 +722,9 @@ def _download_weights(file_url: str, download_path: Path) -> None:
"""
Download weights file from URL
Download the model weights file from the specified URL and save it to the
given path. Ensures the download directory exists, and uses a progress
Download the model weights file from the specified URL and save it
to the given path. Ensures the download directory exists, and uses a
progress
bar to indicate download status.
Parameters
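The _download_weights docstring above describes streaming the file to disk with a progress bar; a generic sketch of that pattern follows, assuming requests and tqdm as the underlying libraries (not necessarily the exact implementation).

from pathlib import Path

import requests
from tqdm import tqdm

def download_with_progress(file_url: str, download_path: Path) -> None:
    """Stream a file to disk while showing a byte-level progress bar."""
    download_path.parent.mkdir(parents=True, exist_ok=True)
    response = requests.get(file_url, stream=True, timeout=10)
    response.raise_for_status()
    total = int(response.headers.get("Content-Length", 0))
    with open(download_path, "wb") as out, tqdm(
        total=total, unit="B", unit_scale=True, desc=download_path.name
    ) as progress:
        for chunk in response.iter_content(chunk_size=2**20):
            out.write(chunk)
            progress.update(len(chunk))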
14 changes: 7 additions & 7 deletions casanovo/utils.py
@@ -161,16 +161,16 @@ def get_report_dict(


def log_run_report(
start_time: Optional[int] = None, end_time: Optional[int] = None
start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""
Log general run report
Parameters
----------
start_time : Optional[int], default=None
start_time : Optional[float], default=None
The start time of the sequencing run in seconds since the epoch.
end_time : Optional[int], default=None
end_time : Optional[float], default=None
The end time of the sequencing run in seconds since the epoch.
"""
logger.info("======= End of Run Report =======")
@@ -197,8 +197,8 @@ def log_run_report(

def log_sequencing_report(
predictions: List[PepSpecMatch],
start_time: Optional[int] = None,
end_time: Optional[int] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
score_bins: List[float] = SCORE_BINS,
) -> None:
"""
@@ -210,9 +210,9 @@
str, Tuple[str, str], float, float, float, float, str
]
PSM predictions
start_time : Optional[int], default=None
start_time : Optional[float], default=None
The start time of the sequencing run in seconds since the epoch.
end_time : Optional[int], default=None
end_time : Optional[float], default=None
The end time of the sequencing run in seconds since the epoch.
score_bins: List[float], Optional
Confidence scores for creating confidence score distribution,
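The int-to-float changes in these time parameters line up with time.time(), which returns a float; below is a minimal, hypothetical call illustrating the intended usage.

import time

start_time = time.time()  # float seconds since the epoch
# ... sequencing work happens here ...
end_time = time.time()
log_run_report(start_time=start_time, end_time=end_time)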
