Skip to content

Commit

Permalink
Add pred acc and ROC curve
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Ramirez committed Sep 5, 2024
1 parent 15039c7 commit b67b8bb
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 2 deletions.
108 changes: 108 additions & 0 deletions pf2/figures/figureA9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Figure A9: PLSR mortality prediction accuracy and ROC curves.
"""
import numpy as np
import pandas as pd
import anndata
from sklearn.metrics import accuracy_score
import seaborn as sns
from ..data_import import convert_to_patients, import_meta
from ..predict import predict_mortality
from .common import subplotLabel, getSetup
from sklearn.metrics import RocCurveDisplay


def makeFigure():
    """Build the two-panel figure: accuracy bars (left) and ROC curves (right).

    Reads the fitted AnnData object from disk, assembles a DataFrame from
    ``uns["Pf2_A"]`` indexed by the output of ``convert_to_patients``, aligns
    the metadata to that index, and then draws the accuracy bar plot on the
    first axis and the ROC curves on the second.

    Returns:
        The figure object produced by ``getSetup``.
    """
    axes, fig = getSetup((6, 3), (1, 2))
    subplotLabel(axes)

    adata = anndata.read_h5ad("/opt/northwest_bal/full_fitted.h5ad")

    meta = import_meta()
    patient_ids = convert_to_patients(adata)

    # Columns are labelled 1..n rather than 0..n-1.
    factors = adata.uns["Pf2_A"]
    component_labels = np.arange(factors.shape[1]) + 1
    patient_factor = pd.DataFrame(
        factors,
        index=patient_ids,
        columns=component_labels,
    )

    # Restrict (and reorder) the metadata to the rows of the factor matrix.
    meta = meta.loc[patient_factor.index, :]

    # Left panel: accuracy overall and per cohort.
    accuracies = plsr_acc(patient_factor, meta)
    acc_axis = axes[0]
    sns.barplot(data=accuracies, ax=acc_axis)
    acc_axis.set(ylim=[0, 1], ylabel="Accuracy")

    # Right panel: ROC curves per cohort.
    plot_plsr_auc_roc(patient_factor, meta, axes[1])

    return fig


def plsr_acc(patient_factor_matrix, meta_data):
    """Compute mortality prediction accuracy overall and per patient category.

    Calls ``predict_mortality`` with ``proba=True``, rounds the returned
    values to integer class predictions, and scores them against the
    returned labels.

    Args:
        patient_factor_matrix: Passed through to ``predict_mortality``.
        meta_data: Metadata with a ``"patient_category"`` column; passed
            through to ``predict_mortality`` and used to split the cohorts.

    Returns:
        A one-row DataFrame with columns ``"Overall"``, ``"C19"`` and
        ``"nC19"`` holding the respective accuracies.
    """
    scores, outcomes = predict_mortality(
        patient_factor_matrix, meta_data, proba=True
    )

    # Threshold the scores into hard integer predictions.
    predicted = scores.round().astype(int)

    # Drop repeated index entries, then align the metadata to the labels.
    deduplicated = meta_data.loc[~meta_data.index.duplicated()]
    aligned_meta = deduplicated.loc[outcomes.index]

    category = aligned_meta.loc[:, "patient_category"]
    is_covid = category == "COVID-19"
    is_other = category != "COVID-19"

    covid_accuracy = accuracy_score(outcomes.loc[is_covid], predicted.loc[is_covid])
    other_accuracy = accuracy_score(outcomes.loc[is_other], predicted.loc[is_other])
    overall_accuracy = accuracy_score(outcomes, predicted)

    acc_df = pd.DataFrame(columns=["Overall", "C19", "nC19"])
    acc_df.loc[0, :] = [overall_accuracy, covid_accuracy, other_accuracy]
    return acc_df


def plot_plsr_auc_roc(patient_factor_matrix, meta_data, ax):
    """Plot ROC curves of the mortality predictions for each patient cohort.

    Calls ``predict_mortality`` with ``proba=True`` and draws one ROC curve
    for rows whose ``"patient_category"`` is ``"COVID-19"`` (named "C19")
    and one for all other rows (named "nC19"), both on ``ax``.

    Args:
        patient_factor_matrix: Passed through to ``predict_mortality``.
        meta_data: Metadata with a ``"patient_category"`` column; passed
            through to ``predict_mortality`` and used to split the cohorts.
        ax: Axis handed to ``RocCurveDisplay.from_predictions`` for drawing.

    Returns:
        None. The curves are drawn on ``ax`` as a side effect.
    """
    probabilities, labels = predict_mortality(
        patient_factor_matrix,
        meta_data,
        proba=True
    )

    # NOTE: the scores are deliberately NOT rounded here. A ROC curve sweeps
    # the decision threshold over continuous scores; binarizing them first
    # collapses each curve to a single operating point.

    # Drop repeated index entries, then align the metadata to the labels.
    meta_data = meta_data.loc[~meta_data.index.duplicated()].loc[labels.index]

    category = meta_data.loc[:, "patient_category"]
    cohorts = (
        ("C19", category == "COVID-19"),
        ("nC19", category != "COVID-19"),
    )

    for cohort_name, in_cohort in cohorts:
        RocCurveDisplay.from_predictions(
            labels.loc[in_cohort],
            probabilities.loc[in_cohort],
            ax=ax,
            name=cohort_name,
        )

7 changes: 5 additions & 2 deletions pf2/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ def run_plsr(
plsr.coef_.squeeze(),
index=data.columns
)

if proba:
return probabilities, plsr

else:
predicted = probabilities.round().astype(int)
return predicted, plsr



def predict_mortality(
Expand Down Expand Up @@ -101,7 +103,7 @@ def predict_mortality(
meta.loc[:, "patient_category"] != "COVID-19",
"binary_outcome"
]

predictions = pd.Series(index=data.index)
predictions.loc[meta.loc[:, "patient_category"] == "COVID-19"], c_plsr = \
run_plsr(
Expand All @@ -114,6 +116,7 @@ def predict_mortality(

if proba:
return predictions, labels

else:
predicted = predictions.round().astype(int)
return accuracy_score(labels, predicted), (c_plsr, nc_plsr)

0 comments on commit b67b8bb

Please sign in to comment.