Eval metrics and circular import bug fix. #380

Merged: 33 commits, Sep 25, 2024
Changes from 10 commits
Commits (33 total):
617fcb8
eval metrics bug fix
Lilferrit Sep 12, 2024
a52ba83
better eval metrics bug fix
Lilferrit Sep 12, 2024
81f4515
eval metrics bug fix
Lilferrit Sep 12, 2024
e30b674
better eval metrics bug fix
Lilferrit Sep 12, 2024
7b6bab3
eval stats unit test, circular import fix
Lilferrit Sep 16, 2024
ddbc93a
log metrics unit test
Lilferrit Sep 17, 2024
00fd170
resolved upstream merge conflict
Lilferrit Sep 17, 2024
9d4109e
removed unused import
Lilferrit Sep 17, 2024
86747d9
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
c863b4a
aa_match_batch handles none, additional skipped spectra test cases
Lilferrit Sep 20, 2024
34c456d
Log optimizer and training metrics to CSV file (#376)
Lilferrit Sep 20, 2024
8f21edb
aa_match_batch and aa_match handle None
Lilferrit Sep 23, 2024
217eeb8
top_match eval metrics warning
Lilferrit Sep 23, 2024
3b27582
removed unused import
Lilferrit Sep 17, 2024
4e89028
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
64a681f
aa_match_batch handles none, additional skipped spectra test cases
Lilferrit Sep 20, 2024
a3d5763
aa_match_batch and aa_match handle None
Lilferrit Sep 23, 2024
8be20ab
top_match eval metrics warning
Lilferrit Sep 23, 2024
60d4159
Merge branch 'eval-metrics-fix' of github.com:Noble-Lab/casanovo into…
Lilferrit Sep 23, 2024
5f38ea8
eval metrics bug fix
Lilferrit Sep 12, 2024
8b6e925
better eval metrics bug fix
Lilferrit Sep 12, 2024
bacf243
eval stats unit test, circular import fix
Lilferrit Sep 16, 2024
5bbbe6f
log metrics unit test
Lilferrit Sep 17, 2024
4788fab
removed unused import
Lilferrit Sep 17, 2024
c473f20
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
63ac6ad
aa_match_batch handles none, additional skipped spectra test cases
Lilferrit Sep 20, 2024
7b4b6e6
aa_match_batch and aa_match handle None
Lilferrit Sep 23, 2024
78bb897
top_match eval metrics warning
Lilferrit Sep 23, 2024
fb975b2
removed unused import
Lilferrit Sep 17, 2024
692cd7e
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
7740a77
metrics file logging bug fix
Lilferrit Sep 23, 2024
e9bb5ec
merge conflicts
Lilferrit Sep 23, 2024
60524af
aa_match test cases, minor aa_match refactor
Lilferrit Sep 24, 2024
41 changes: 2 additions & 39 deletions casanovo/data/ms_io.py
@@ -2,54 +2,17 @@

import collections
import csv
import dataclasses
import operator
import os
import re
from pathlib import Path
from typing import List, Tuple, Iterable
from typing import List

import natsort

from .. import __version__
from ..config import Config


@dataclasses.dataclass
class PepSpecMatch:
"""
Peptide Spectrum Match (PSM) dataclass

Parameters
----------
sequence : str
The amino acid sequence of the peptide.
spectrum_id : Tuple[str, str]
A tuple containing the spectrum identifier in the form
(spectrum file name, spectrum file idx)
peptide_score : float
Score of the match between the full peptide sequence and the
spectrum.
charge : int
The precursor charge state of the peptide ion observed in the spectrum.
calc_mz : float
The calculated mass-to-charge ratio (m/z) of the peptide based on its
sequence and charge state.
exp_mz : float
The observed (experimental) precursor mass-to-charge ratio (m/z) of the
peptide as detected in the spectrum.
aa_scores : Iterable[float]
A list of scores for individual amino acids in the peptide
sequence, where len(aa_scores) == len(sequence)
"""

sequence: str
spectrum_id: Tuple[str, str]
peptide_score: float
charge: int
calc_mz: float
exp_mz: float
aa_scores: Iterable[float]
from .psm import PepSpecMatch


class MztabWriter:
41 changes: 41 additions & 0 deletions casanovo/data/psm.py
@@ -0,0 +1,41 @@
"""Peptide spectrum match dataclass"""

import dataclasses
from typing import Tuple, Iterable


@dataclasses.dataclass
class PepSpecMatch:
"""
Peptide Spectrum Match (PSM) dataclass

Parameters
----------
sequence : str
The amino acid sequence of the peptide.
spectrum_id : Tuple[str, str]
A tuple containing the spectrum identifier in the form
(spectrum file name, spectrum file idx)
peptide_score : float
Score of the match between the full peptide sequence and the
spectrum.
charge : int
The precursor charge state of the peptide ion observed in the spectrum.
calc_mz : float
The calculated mass-to-charge ratio (m/z) of the peptide based on its
sequence and charge state.
exp_mz : float
The observed (experimental) precursor mass-to-charge ratio (m/z) of the
peptide as detected in the spectrum.
aa_scores : Iterable[float]
A list of scores for individual amino acids in the peptide
sequence, where len(aa_scores) == len(sequence)
"""

sequence: str
spectrum_id: Tuple[str, str]
peptide_score: float
charge: int
calc_mz: float
exp_mz: float
aa_scores: Iterable[float]
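
The circular-import half of the fix works by moving `PepSpecMatch` into a small leaf module, `casanovo/data/psm.py`, that imports nothing else from the package. The exact import cycle is not spelled out in the diff; the sketch below illustrates the pattern, with the offending chain inferred from the files touched in this PR rather than quoted from it:

```python
# Inferred shape of the problem and the fix (illustrative, not literal source).
#
# Before: utils.py needed only the PepSpecMatch dataclass, but importing it
# dragged in ms_io and, through it, the config machinery:
#
#   casanovo/utils.py       -> casanovo/data/ms_io.py   (for PepSpecMatch)
#   casanovo/data/ms_io.py  -> casanovo/config.py       (for Config)
#   ...                     -> casanovo/utils.py        (closing the cycle)
#
# After: the dataclass lives in a dependency-free module, so importing it
# cannot re-enter the package:
#
#   casanovo/utils.py       -> casanovo/data/psm.py     (dataclass only)
#   casanovo/data/ms_io.py  -> casanovo/data/psm.py and casanovo/config.py
```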
6 changes: 6 additions & 0 deletions casanovo/denovo/evaluate.py
@@ -225,8 +225,14 @@
# Split peptides into individual AAs if necessary.
if isinstance(peptide1, str):
peptide1 = re.split(r"(?<=.)(?=[A-Z])", peptide1)
elif peptide1 is None:
peptide1 = []

if isinstance(peptide2, str):
peptide2 = re.split(r"(?<=.)(?=[A-Z])", peptide2)
elif peptide2 is None:
peptide2 = []

n_aa1, n_aa2 = n_aa1 + len(peptide1), n_aa2 + len(peptide2)
aa_matches_batch.append(
aa_match(
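
With this change, `aa_match_batch` accepts `None` in place of a peptide and treats it as an empty sequence, so spectra that produced no prediction can still flow through evaluation. A minimal usage sketch, assuming the call signature and return-value ordering used at the `model_runner.py` call site below and the `depthcharge` residue-mass dictionary that call already relies on:

```python
# Minimal sketch of evaluating a batch that contains a skipped spectrum.
# Signatures are assumed from the model_runner.py call site in this PR.
from depthcharge.masses import PeptideMass

from casanovo.denovo.evaluate import aa_match_batch, aa_match_metrics

seq_true = ["PEP", "PET", "PEI"]   # annotations from the test index
seq_pred = ["PEP", None, "PEI"]    # None: no prediction for the second spectrum

aa_precision, _, pep_precision = aa_match_metrics(
    *aa_match_batch(seq_true, seq_pred, PeptideMass().masses)
)
# The None entry contributes zero predicted residues: it lowers peptide
# precision but no longer raises an error during batch evaluation.
```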
4 changes: 2 additions & 2 deletions casanovo/denovo/model.py
@@ -16,7 +16,7 @@

from . import evaluate
from .. import config
from ..data import ms_io
from ..data import ms_io, psm

logger = logging.getLogger("casanovo")

@@ -914,7 +914,7 @@ def on_predict_batch_end(
if len(peptide) == 0:
continue
self.out_writer.psms.append(
ms_io.PepSpecMatch(
psm.PepSpecMatch(
sequence=peptide,
spectrum_id=tuple(spectrum_i),
peptide_score=peptide_score,
24 changes: 17 additions & 7 deletions casanovo/denovo/model_runner.py
@@ -163,18 +163,28 @@
Index containing the annotated spectra used to generate model
predictions
"""
model_output = [psm.sequence for psm in self.writer.psms]
spectrum_annotations = [
test_index[i][4] for i in range(test_index.n_spectra)
]
seq_pred = []
seq_true = []
pred_idx = 0

with test_index as t_ind:
for true_idx in range(t_ind.n_spectra):
seq_true.append(t_ind[true_idx][4])
if pred_idx < len(self.writer.psms) and self.writer.psms[
pred_idx
].spectrum_id == t_ind.get_spectrum_id(true_idx):
seq_pred.append(self.writer.psms[pred_idx].sequence)
pred_idx += 1
else:
seq_pred.append(None)

aa_precision, _, pep_precision = aa_match_metrics(
*aa_match_batch(
spectrum_annotations,
model_output,
seq_true,
seq_pred,
depthcharge.masses.PeptideMass().masses,
)
)

logger.info("Peptide Precision: %.2f%%", 100 * pep_precision)
logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision)

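The reworked `log_metrics` above walks the annotated index and the predicted PSMs in lockstep: because PSMs are written in spectrum order, a single forward-only pointer pairs each annotated spectrum with its prediction, and any spectrum without a PSM gets `None` (which the `evaluate.py` change makes safe to score). A standalone sketch of just that alignment step, using plain tuples in place of the real index object:

```python
# Standalone sketch of the alignment logic in log_metrics. Assumes predictions
# are stored in the same spectrum order as the index, as in the PR.
from typing import List, Optional, Tuple

def align_predictions(
    true_ids: List[Tuple[str, str]],           # spectrum ids from the annotated index
    pred: List[Tuple[Tuple[str, str], str]],   # (spectrum_id, sequence) per PSM
) -> List[Optional[str]]:
    seq_pred: List[Optional[str]] = []
    pred_idx = 0
    for spec_id in true_ids:
        if pred_idx < len(pred) and pred[pred_idx][0] == spec_id:
            seq_pred.append(pred[pred_idx][1])
            pred_idx += 1
        else:
            # No prediction for this spectrum: it was skipped during inference.
            seq_pred.append(None)
    return seq_pred

# Example: the spectrum with index=2 was skipped, so its slot becomes None.
print(align_predictions(
    [("foo", "index=1"), ("foo", "index=2"), ("foo", "index=3")],
    [(("foo", "index=1"), "PEP"), (("foo", "index=3"), "PEI")],
))  # -> ['PEP', None, 'PEI']
```
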
@@ -272,7 +282,7 @@
tb_summarywriter = None
if self.config.tb_summarywriter:
if self.output_dir is None:
logger.warning(
"Can not create tensorboard because the output directory "
"is not set in the model runner."
)
2 changes: 1 addition & 1 deletion casanovo/utils.py
@@ -15,7 +15,7 @@
import psutil
import torch

from .data.ms_io import PepSpecMatch
from .data.psm import PepSpecMatch


SCORE_BINS = [0.0, 0.5, 0.9, 0.95, 0.99]
170 changes: 170 additions & 0 deletions tests/unit_tests/test_runner.py
@@ -1,11 +1,13 @@
"""Unit tests specifically for the model_runner module."""

import unittest.mock
from pathlib import Path

import pytest
import torch

from casanovo.config import Config
from casanovo.data.psm import PepSpecMatch
from casanovo.denovo.model_runner import ModelRunner


@@ -282,3 +284,171 @@ def test_evaluate(
)

result_file.unlink()


def test_log_metrics(monkeypatch, tiny_config):
def get_mock_index(psm_list):
mock_test_index = unittest.mock.MagicMock()
mock_test_index.__enter__.return_value = mock_test_index
mock_test_index.__exit__.return_value = False
mock_test_index.n_spectra = len(psm_list)
mock_test_index.get_spectrum_id = lambda idx: psm_list[idx].spectrum_id

mock_spectra = [
(None, None, None, None, curr_psm.sequence)
for curr_psm in psm_list
]
mock_test_index.__getitem__.side_effect = lambda idx: mock_spectra[idx]
return mock_test_index

def get_mock_psm(sequence, spectrum_id):
return PepSpecMatch(
sequence=sequence,
spectrum_id=spectrum_id,
peptide_score=None,
charge=None,
exp_mz=None,
aa_scores=None,
calc_mz=None,
)

with monkeypatch.context() as ctx:
mock_logger = unittest.mock.MagicMock()
ctx.setattr("casanovo.denovo.model_runner.logger", mock_logger)

with ModelRunner(Config(tiny_config)) as runner:
runner.writer = unittest.mock.MagicMock()

# Test 100% peptide precision
infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
]

act_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100)
assert aa_precision == pytest.approx(100)

# Test 50% peptide precision (one wrong)
infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
]

act_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PEP", ("foo", "index=2")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100 * (1 / 2))
assert aa_precision == pytest.approx(100 * (5 / 6))

# Test skipped spectra
act_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
get_mock_psm("PEI", ("foo", "index=3")),
get_mock_psm("PEG", ("foo", "index=4")),
get_mock_psm("PEA", ("foo", "index=5")),
]

infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
get_mock_psm("PEI", ("foo", "index=3")),
get_mock_psm("PEA", ("foo", "index=5")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100 * (4 / 5))
assert aa_precision == pytest.approx(100)

infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
get_mock_psm("PEI", ("foo", "index=3")),
get_mock_psm("PEG", ("foo", "index=4")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100 * (4 / 5))
assert aa_precision == pytest.approx(100)

infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PEI", ("foo", "index=3")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100 * (2 / 5))
assert aa_precision == pytest.approx(100)

infer_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PEA", ("foo", "index=5")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(100 * (2 / 5))
assert aa_precision == pytest.approx(100)

# Test un-inferred spectra
act_psms = [
get_mock_psm("PEP", ("foo", "index=1")),
get_mock_psm("PET", ("foo", "index=2")),
get_mock_psm("PEI", ("foo", "index=3")),
get_mock_psm("PEG", ("foo", "index=4")),
]

infer_psms = [
get_mock_psm("PE", ("foo", "index=1")),
get_mock_psm("PE", ("foo", "index=2")),
get_mock_psm("PE", ("foo", "index=3")),
get_mock_psm("PE", ("foo", "index=4")),
get_mock_psm("PE", ("foo", "index=5")),
]

runner.writer.psms = infer_psms
mock_index = get_mock_index(act_psms)
runner.log_metrics(mock_index)

pep_precision = mock_logger.info.call_args_list[-2][0][1]
aa_precision = mock_logger.info.call_args_list[-1][0][1]
assert pep_precision == pytest.approx(0)
assert aa_precision == pytest.approx(100)
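
As a sanity check on the expected values in the 50% peptide-precision case above: `PEP` vs. `PEP` matches 3 of 3 residues and `PET` vs. `PEP` matches 2 of 3, so 5 of 6 predicted residues are correct while only 1 of 2 peptides is fully correct. A throwaway snippet (not part of the PR; it uses positional identity rather than the mass-tolerant comparison the real `aa_match` performs) reproducing that arithmetic:

```python
# Quick arithmetic check for the 50% peptide precision test case above
# (hypothetical helper, not part of the casanovo test suite).
pairs = [("PEP", "PEP"), ("PEP", "PET")]  # (annotated, predicted)

aa_matched = sum(
    sum(a == b for a, b in zip(true_seq, pred_seq))
    for true_seq, pred_seq in pairs
)
aa_total = sum(len(pred_seq) for _, pred_seq in pairs)
pep_correct = sum(true_seq == pred_seq for true_seq, pred_seq in pairs)

assert aa_matched / aa_total == 5 / 6     # matches pytest.approx(100 * (5 / 6))
assert pep_correct / len(pairs) == 1 / 2  # matches pytest.approx(100 * (1 / 2))
```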