Skip to content

Commit

Permalink
Finished first model revision
Browse files Browse the repository at this point in the history
  • Loading branch information
wfondrie committed Oct 28, 2023
1 parent e0be302 commit 3400de1
Showing 1 changed file with 96 additions and 91 deletions.
187 changes: 96 additions & 91 deletions mokapot/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""mokapot implements an algorithm for training machine learning models to
"""Mokapot models.
Mokapot implements an algorithm for training machine learning models to
distinguish high-scoring target peptide-spectrum matches (PSMs) from decoy PSMs
using an iterative procedure. It is the :py:class:`Model` class that contains
this logic. A :py:class:`Model` instance can be created from any object with a
Expand Down Expand Up @@ -151,12 +153,6 @@ def __init__(
# multiprocessing.
self.fold = None

# Sort out whether we need to optimize hyperparameters:
if isinstance(self.estimator, BaseSearchCV):
self._needs_cv = True
else:
self._needs_cv = False

def __repr__(self) -> str:
"""How to print the class."""
trained = {True: "A trained", False: "An untrained"}
Expand Down Expand Up @@ -241,21 +237,21 @@ def predict(self, dataset: PsmDataset) -> np.ndarray:
def fit(self, dataset: PsmDataset) -> Model:
"""Fit the model using the Percolator algorithm.
The model if trained by iteratively learning to separate decoy
PSMs from high-scoring target examples. By default, an initial
direction is chosen as the feature that best separates target
from decoy examples. A false discovery rate threshold is used to
define how high a target must score to be used as a positive
example in the next training iteration.
The model if trained by iteratively learning to separate decoy examples
from high-scoring target examples. By default, an initial direction is
chosen as the feature that best separates target from decoy examples. A
false discovery rate threshold is used to define how high a target must
score to be used as a positive example in the next training iteration.
Parameters
----------
dataset : PsmDataset object
dataset : PsmDataset
The dataset from which to train the model.
Returns
-------
self
"""
if not dataset.targets.sum():
raise ValueError(
Expand Down Expand Up @@ -288,7 +284,7 @@ def fit(self, dataset: PsmDataset) -> Model:
start_labels = start_labels[shuffled_idx]

# Prepare the model:
model = _find_hyperparameters(self, norm_feat, start_labels)
model = self._find_hyperparameters(norm_feat, start_labels)

# Begin training loop
target = start_labels
Expand All @@ -310,7 +306,10 @@ def fit(self, dataset: PsmDataset) -> Model:
num_passed.append((target == 1).sum())

LOGGER.info(
" - Iteration %i: %i training PSMs passed.", i, num_passed[i]
" - Iteration %i: %i training %s passed.",
i,
num_passed[i],
dataset.unit,
)

# If the model performs worse than what was initialized:
Expand All @@ -321,12 +320,7 @@ def fit(self, dataset: PsmDataset) -> Model:
raise RuntimeError("Model performs worse after training.")

self.estimator = model
weights = _get_weights(self.estimator, self.features)
if weights is not None:
LOGGER.info("Normalized feature weights in the learned model:")
for line in weights:
LOGGER.info(" %s", line)

self._log_weights()
self.is_trained = True
LOGGER.info("Done training.")
return self
Expand Down Expand Up @@ -462,13 +456,75 @@ def load(cls, model_file: PathLike) -> Model:

raise TypeError("This file did not contain a mokapot Model.")

def _find_hyperparameters(
self,
features: np.ndarray,
labels: np.ndarray,
) -> ClassifierMixin:
"""Find the hyperparameters for the model.
Parameters
----------
features : array-like
The features to fit the model with.
labels : array-like
The labels for each PSM (1, 0, or -1).
Returns
-------
An estimator.
"""
if isinstance(self.estimator, BaseSearchCV):
LOGGER.info("Selecting hyperparameters...")
cv_samples = features[labels.astype(bool), :]
cv_targ = (labels[labels.astype(bool)] + 1) / 2

# Fit the model
self.estimator.fit(cv_samples, cv_targ)

# Extract the best params.
best_params = self.estimator.best_params_
new_est = self.estimator.estimator
new_est.set_params(**best_params)
for param, value in best_params.items():
LOGGER.info("\t- %s = %s", param, value)

return new_est

return self.estimator

def _log_weights(self) -> None:
"""If the model is a linear model, log the weights."""
try:
weights = self.estimator.coef_
intercept = self.estimator.intercept_
criteria = [
weights.shape[0] == 1,
weights.shape[1] == len(self.features),
len(intercept) <= 1,
]
if not all(criteria):
raise ValueError

weights = list(weights.flatten())
except (AttributeError, ValueError):
LOGGER.info("Normalized feature weights in the learned model:")
col_width = max(len(f) for f in self.features) + 2
LOGGER.info("Feature %s Weight", " " * (col_width - 8))
for weight, feature in zip(weights, self.features):
space = " " * (col_width - len(feature))
LOGGER.info(" %s", feature + space + str(weight))

LOGGER.info(
" intercept%s", " " * (col_width - 9) + str(intercept[0])
)


class PercolatorModel(Model):
"""A model that emulates Percolator.
Create linear support vector machine (SVM) model that is similar
to the one used by Percolator. This is the default model used by
mokapot.
Create linear support vector machine (SVM) model that is similar to the one
used by Percolator. This is the default model used by mokapot.
Parameters
----------
Expand Down Expand Up @@ -529,6 +585,7 @@ class PercolatorModel(Model):
grid search.
rng : numpy.random.Generator
The random number generator.
"""

def __init__(
Expand Down Expand Up @@ -565,7 +622,8 @@ def __init__(


class DummyScaler:
"""
"""A dummy scaler.
Implements the interface of scikit-learn scalers, but does
nothing to the data. This simplifies the training code.
Expand All @@ -585,7 +643,7 @@ def transform(self, x: np.ndarray) -> np.ndarray:
return x


def save_model(model: Model, out_file: PathLike):
def save_model(model: Model, out_file: PathLike) -> Path:
"""
Save a :py:class:`mokapot.model.Model` object to a file.
Expand All @@ -610,75 +668,22 @@ def save_model(model: Model, out_file: PathLike):
return model.save(out_file)


def _find_hyperparameters(model, features, labels):
"""
Find the hyperparameters for the model.
def load_model(model_file: PathLike) -> Model:
"""Load a saved model for mokapot.
Parameters
----------
model : a mokapot.Model
The model to fit.
features : array-like
The features to fit the model with.
labels : array-like
The labels for each PSM (1, 0, or -1).
model_file : PathLike
The name of file from which to load the model.
Returns
-------
An estimator.
"""
if model._needs_cv:
LOGGER.info("Selecting hyperparameters...")
cv_samples = features[labels.astype(bool), :]
cv_targ = (labels[labels.astype(bool)] + 1) / 2
mokapot.model.Model
The loaded mokapot model.
# Fit the model
model.estimator.fit(cv_samples, cv_targ)

# Extract the best params.
best_params = model.estimator.best_params_
new_est = model.estimator.estimator
new_est.set_params(**best_params)
model._needs_cv = False
for param, value in best_params.items():
LOGGER.info("\t- %s = %s", param, value)
else:
new_est = model.estimator

return new_est


def _get_weights(model, features):
"""
If the model is a linear model, parse the weights to a list of strings.
Parameters
----------
model : estimator
An sklearn linear_model object
features : list of str
The feature names, in order.
Returns
-------
list of str
The weights associated with each feature.
Warnings
--------
Unpickling data in Python is unsafe. Make sure that the model is from
a source that you trust.
"""
try:
weights = model.coef_
intercept = model.intercept_
assert weights.shape[0] == 1
assert weights.shape[1] == len(features)
assert len(intercept) == 1
weights = list(weights.flatten())
except (AttributeError, AssertionError):
return None

col_width = max([len(f) for f in features]) + 2
txt_out = ["Feature" + " " * (col_width - 7) + "Weight"]
for weight, feature in zip(weights, features):
space = " " * (col_width - len(feature))
txt_out.append(feature + space + str(weight))

txt_out.append("intercept" + " " * (col_width - 9) + str(intercept[0]))
return txt_out
return Model.load(model_file)

0 comments on commit 3400de1

Please sign in to comment.