Skip to content

Commit

Permalink
Updated model tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
wfondrie committed Oct 28, 2023
1 parent 3400de1 commit a4e9725
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 44 deletions.
2 changes: 1 addition & 1 deletion mokapot/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
"""Initialize the mokapot package."""
from .confidence import PsmConfidence
from .dataset import PsmDataset
from .model import Model, PercolatorModel, load_model, save_model
from .parsers.fasta import digest, make_decoys, read_fasta
from .parsers.pin import percolator_to_df, read_pin
from .schema import PsmSchema
from .version import _get_version
from .writers import to_csv, to_flashlfq, to_parquet, to_txt

# from .model import Model, PercolatorModel, save_model, load_model
# from .brew import brew
# from .parsers.pepxml import read_pepxml

Expand Down
4 changes: 1 addition & 3 deletions mokapot/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,9 +59,7 @@ def __init__(
self.schema = schema
self.eval_fdr = eval_fdr
self.proteins = proteins

# Private:
self._unit = unit
self.unit = unit
self._len = None # We cache this for speed.

# Try and read data.
Expand Down
15 changes: 8 additions & 7 deletions mokapot/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -108,18 +108,18 @@ def __init__(
LOGGER.info(
" - %i target %s and %i decoy %s detected.",
num_targets,
self._unit,
self.unit,
num_decoys,
self._unit,
self.unit,
)

# Validate the target column.
if not num_targets:
raise ValueError(f"No target {self._unit} were detected.")
raise ValueError(f"No target {self.unit} were detected.")
if not num_decoys:
raise ValueError(f"No decoy {self._unit} were detected.")
raise ValueError(f"No decoy {self.unit} were detected.")
if not len(self):
raise ValueError(f"No {self._unit} were detected.")
raise ValueError(f"No {self.unit} were detected.")

@property
def features(self) -> np.ndarray:
Expand Down Expand Up @@ -164,8 +164,9 @@ def update_labels(
training.
"""
if len(scores.shape) > 1 and sum(np.array(scores.shape) > 0) == 1:
scores = scores.flatten()
scores = np.array(scores).squeeze()
if len(scores.shape) > 1:
raise ValueError("scores must be one dimensional.")

qvals = qvalues.tdc(scores, target=self.targets, desc=desc)
unlabeled = np.logical_and(qvals > self.eval_fdr, self.targets)
Expand Down
61 changes: 28 additions & 33 deletions tests/unit_tests/test_model.py
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,25 @@
"""Test that models work as expected"""
import pytest
import mokapot
"""Test that models work as expected."""
import numpy as np
import pandas as pd
from sklearn.svm import LinearSVC
import polars as pl
import pytest
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.svm import LinearSVC

import mokapot


def test_model_init():
"""Test that a model initializes correctly"""
"""Test that a model initializes correctly."""
model = mokapot.Model(
LogisticRegression(),
scaler=MinMaxScaler(),
train_fdr=0.05,
max_iter=1,
direction="score",
override=True,
subset_max_train=500,
shuffle=False,
)

Expand All @@ -29,7 +29,6 @@ def test_model_init():
assert model.max_iter == 1
assert model.direction == "score"
assert model.override
assert model.subset_max_train == 500
assert not model.shuffle
assert not model.is_trained

Expand All @@ -41,14 +40,13 @@ def test_model_init():


def test_perc_init():
"""Test the initialization of a PercolatorModel"""
"""Test the initialization of a PercolatorModel."""
model = mokapot.PercolatorModel(
scaler="as-is",
train_fdr=0.05,
max_iter=1,
direction="score",
override=True,
subset_max_train=500,
)
assert isinstance(model.estimator, GridSearchCV)
assert isinstance(model.estimator.estimator, LinearSVC)
Expand All @@ -57,11 +55,10 @@ def test_perc_init():
assert model.max_iter == 1
assert model.direction == "score"
assert model.override
assert model.subset_max_train == 500


def test_model_fit(psms):
"""Test that model fitting works"""
"""Test that model fitting works."""
model = mokapot.Model(LogisticRegression(), train_fdr=0.05, max_iter=1)
model.fit(psms)

Expand All @@ -73,29 +70,25 @@ def test_model_fit(psms):
assert isinstance(model.estimator, LogisticRegression)
assert model.is_trained

no_targets = pd.DataFrame({"targets": [False] * 100})
class DummyData:
    """Minimal stand-in object handed to ``model.fit`` in this test.

    It exposes only a ``targets`` array, a ``unit`` label, and a length.
    """

    def __init__(self, tgt):
        """Store the labels and a fixed unit name.

        Parameters
        ----------
        tgt : array-like
            Values stored as ``targets``; the test passes lists of booleans.
        """
        self.targets = np.array(tgt)
        self.unit = "blah"

    def __len__(self):
        """Return the number of entries in ``targets``."""
        return len(self.targets)

no_targets = DummyData([False] * 100)
with pytest.raises(ValueError):
model.fit(no_targets)

no_decoys = pd.DataFrame({"targets": [True] * 100})
no_decoys = DummyData([True] * 100)
with pytest.raises(ValueError):
model.fit(no_decoys)


def test_model_fit_large_subset(psms):
    """Fit a model with ``subset_max_train`` set to 2,000,000,000.

    The only check is that the model reports itself as trained afterwards.
    """
    model = mokapot.Model(
        LogisticRegression(),
        train_fdr=0.05,
        max_iter=1,
        subset_max_train=2_000_000_000,
    )
    model.fit(psms)

    assert model.is_trained


def test_model_predict(psms):
"""Test predictions"""
"""Test predictions."""
model = mokapot.Model(LogisticRegression(), train_fdr=0.05, max_iter=1)

try:
Expand All @@ -109,14 +102,16 @@ def test_model_predict(psms):
assert len(scores) == len(psms)

# The case where a model is trained on a dataset with different features:
psms._data["blah"] = np.random.randn(len(psms))
psms._feature_columns = ("score", "blah")
psms._data = psms.data.with_columns(
pl.Series(np.random.randn(len(psms))).alias("blah")
)
psms.schema.features = ["score", "blah"]
with pytest.raises(ValueError):
model.predict(psms)


def test_model_persistance(tmp_path):
"""test that we can save and load a model"""
"""Test that we can save and load a model."""
model_file = str(tmp_path / "model.pkl")

model = mokapot.Model(LogisticRegression(), train_fdr=0.05, max_iter=1)
Expand All @@ -127,7 +122,7 @@ def test_model_persistance(tmp_path):


def test_dummy_scaler():
"""Test the DummyScaler class"""
"""Test the DummyScaler class."""
data = np.random.default_rng(42).normal(0, 1, (20, 10))
scaler = mokapot.model.DummyScaler()
assert (data == scaler.fit_transform(data)).all()
Expand Down

0 comments on commit a4e9725

Please sign in to comment.