Skip to content

Commit

Permalink
converted to ensemble model
Browse files Browse the repository at this point in the history
  • Loading branch information
lincoln-harris committed Aug 29, 2024
1 parent 474a10b commit 08c43ce
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 68 deletions.
168 changes: 110 additions & 58 deletions lupine/lupine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@
This modules contains the `Lupine` class and the implementation of
the `impute` command. `Lupine` is the high-level implementation for
a PyTorch model for imputing protein-level quantifications using
deep matrix factorization. Missing values are imputed by taking the
a multilayer perceptron. Missing values are imputed by taking the
concatenation of the corresponding protein and run factors and
feeding them through a deep neural network.
This module implements the method's `impute` command, which fits an
ensemble of Lupine models to the provided matrix and writes a single
consensus imputed quants matrix as output.
"""
from lupine.lupine_base import LupineBase
import torch
import click
import pandas as pd
import numpy as np
import torch
import shutil

from lupine.os_utils import os
from pathlib import Path
Expand Down Expand Up @@ -113,66 +118,113 @@ def forward(self, locs):
@click.command()
@click.argument("csv", required=True, nargs=1)

@click.option("--n_prot_factors", default=128,
help="Number of protein factors", required=False, type=int)
@click.option("--n_run_factors", default=128,
help="Number of run factors", required=False, type=int)
@click.option("--n_layers", default=2,
help="Number of hidden layers", required=False, type=int)
@click.option("--n_nodes", default=1024,
help="Number of nodes per layer", required=False, type=int)
@click.option("--rand_seed", default=None, help="Random seed",
required=False, type=int)
@click.option("--outpath", required=True, nargs=1, type=str,
help="Output directory")
@click.option("--n_models", default=10,
help="The number of models to fit.", required=False, type=int)
@click.option("--biased", default=True,
help="Biased batch selection?", required=False, type=bool)
help="Biased batch selection?", required=False, type=bool)
@click.option("--device", default="cpu",
help="The device to load model on", required=False, type=str)
help="The device to load model on", required=False, type=str)
@click.option("--mode", default="run",
help="The model run mode.", required=False, type=str)
help="The model run mode.", required=False, type=str)

def impute(
csv,
n_prot_factors,
n_run_factors,
n_layers,
n_nodes,
rand_seed,
biased,
device,
mode,
csv,
outpath,
n_models,
biased,
device,
mode,
):
"""Impute missing values in a protein or peptide quantifications matrix."""

# Read in the csv
mat_pd = pd.read_csv(csv, index_col=0)
rows = list(mat_pd.index)
cols = list(mat_pd.columns)
mat = np.array(mat_pd)

test_bool = False
if mode == "Testing":
test_bool = True

Path("results/").mkdir(parents=True, exist_ok=True)

# Init the model
model = Lupine(
n_prots=mat.shape[0],
n_runs=mat.shape[1],
n_prot_factors=n_prot_factors,
n_run_factors=n_run_factors,
n_layers=n_layers,
n_nodes=n_nodes,
rand_seed=rand_seed,
testing=test_bool,
biased=biased,
device=device
)
# Fit the model
print("fitting model")
model_recon = model.fit_transform(mat)

print("done!")
model_recon_pd = \
pd.DataFrame(model_recon, index=rows, columns=cols)
pd.DataFrame(model_recon_pd, "results/lupine_recon_quants.csv")
"""
Impute missing values in a protein or peptide quantifications
matrix.
"""

# Read in the csv
mat_pd = pd.read_csv(csv, index_col=0)
rows = list(mat_pd.index)
cols = list(mat_pd.columns)
mat = np.array(mat_pd)

test_bool = False
if mode == "Testing":
test_bool = True

# Define the full hyperparam search spaces a
gen = np.random.default_rng(seed=18)
n_layers_hparam_space=[1, 2]
n_factors_hparam_space=[32, 64, 128, 256]
n_nodes_hparam_space=[256, 512, 1024, 2048]

print(" ")
print("----------------------------------")
print("-------- L U P I N E ---------")
print("----------------------------------")
print(" ")
print(f"Fitting ensemble of models on: {device}\n")

Path(outpath).mkdir(parents=True, exist_ok=True)
Path(outpath+"/tmp").mkdir(parents=True, exist_ok=True)

# The driver loop for ensemble model
for n_iter in range(0, n_models):
print(f"Fitting model {n_iter+1} of {n_models}")

# Randomly select the hparams
n_layers_curr = gen.choice(n_layers_hparam_space)
prot_factors_curr = gen.choice(n_factors_hparam_space)
run_factors_curr = gen.choice(n_factors_hparam_space)
n_nodes_curr = gen.choice(n_nodes_hparam_space)

curr_seed = gen.integers(low=1, high=1e4)

# Init an individual model
model = Lupine(
n_prots=mat.shape[0],
n_runs=mat.shape[1],
n_prot_factors=prot_factors_curr,
n_run_factors=run_factors_curr,
n_layers=n_layers_curr,
n_nodes=n_nodes_curr,
rand_seed=curr_seed,
testing=test_bool,
biased=biased,
device=device
)

# Fit the individual model
model_recon = model.fit_transform(mat)
model_recon_pd = \
pd.DataFrame(model_recon, index=rows, columns=cols)

# Write.
# These filenames may be helpful for debugging.
outpath_curr = \
outpath + "tmp/qmat_tmp_" + \
str(n_layers_curr) + "layers_" + \
str(prot_factors_curr) + "protFactors_" + \
str(run_factors_curr) + "runFactors_" + \
str(n_nodes_curr) + "nodes_" + \
str(curr_seed) + "seed" + ".csv"

model_recon_pd.to_csv(outpath_curr)

# Do the model ensembling
qmats = []
for n_iter in range(0, n_models):
curr_path = outpath + "tmp/qmat_tmp" + str(n_iter) + ".csv"
tmp = pd.read_csv(curr_path)
qmats.append(tmp)

qmats_mean = np.mean(qmats, axis=0)
outpath_ensemble = outpath + "lupine_recon_quants.csv"
pd.DataFrame(qmats_mean).to_csv(outpath_ensemble)
shutil.rmtree(outpath+"tmp")

print(" ")
print("Done!")
print("----------------------------------")
print("----------------------------------")
print(" ")
19 changes: 9 additions & 10 deletions lupine/lupine_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class LupineBase(torch.nn.Module):
The tolerance criteria for early stopping, according to the
standard early stopping criteria
max_epochs : int, optional,
The maximum number of training epochs for the model
The maximum number of training epochs for the model.
Default 42.
patience : int, optional
The number of training epochs to wait before stopping if
it seems like the model has converged
Expand All @@ -76,7 +77,7 @@ def __init__(
learning_rate=0.01,
batch_size=128,
tolerance=0.001,
max_epochs=128,
max_epochs=1,
patience=10,
rand_seed=None,
testing=False,
Expand Down Expand Up @@ -104,13 +105,11 @@ def __init__(
torch.manual_seed(self.rand_seed)

# For writing the model state to disk
self.MODELPATH = "results/OPT_MODEL_INTERNAL.pt"

# Is there a better way to do this?
try:
os.remove(self.MODELPATH)
except FileNotFoundError:
pass
#self.MODELPATH = "results/OPT_MODEL_INTERNAL.pt"
# try:
# os.remove(self.MODELPATH)
# except FileNotFoundError:
# pass

# Init the run factors
self.run_factors = torch.nn.Parameter(
Expand Down Expand Up @@ -291,7 +290,7 @@ def fit(self, X_mat, X_val_mat=None):
curr_loss = train_loss

if curr_loss < best_loss:
torch.save(self, self.MODELPATH)
#torch.save(self, self.MODELPATH)
best_loss = curr_loss

# Evaluate early stopping:
Expand Down

0 comments on commit 08c43ce

Please sign in to comment.