Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update plsr #70

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pf2/figures/figureA4.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def makeFigure():
columns=[f"Cmp. {i}" for i in np.arange(1, X.uns["Pf2_A"].shape[1] + 1)],
)

pc_df = partial_correlation_matrix(condition_factors_df, f)
pc_df = partial_correlation_matrix(condition_factors_df)

f = plot_partial_correlation_matrix(pc_df, f)

Expand Down
32 changes: 14 additions & 18 deletions pf2/figures/figureA8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .common import subplotLabel, getSetup
from ..tensor import correct_conditions
from .figureA4 import partial_correlation_matrix
from ..data_import import condition_factors_meta
from ..figures.commonFuncs.plotGeneral import bal_combine_bo_covid


def makeFigure():
Expand All @@ -23,13 +25,13 @@ def makeFigure():
X.uns["Pf2_A"] += np.median(X.uns["Pf2_A"], axis=0)
X.uns["Pf2_A"] = np.log(X.uns["Pf2_A"])

condition_factors_df = pd.DataFrame(
data=X.uns["Pf2_A"],
columns=[f"Cmp. {i}" for i in np.arange(1, X.uns["Pf2_A"].shape[1] + 1)],
)
pc_df = partial_correlation_matrix(condition_factors_df)
pc_df = remove_low_pc_cmp(pc_df, abs_threshold=0.2)
cmp_columns = [f"Cmp. {i}" for i in np.arange(1, X.uns["Pf2_A"].shape[1] + 1)]

cond_fact_meta_df = condition_factors_meta(X)
cond_fact_meta_df = bal_combine_bo_covid(cond_fact_meta_df)

pc_df = partial_correlation_matrix(cond_fact_meta_df[cmp_columns])
pc_df = remove_low_pc_cmp(pc_df, abs_threshold=0.4)

pc_df["Var1"] = pc_df["Var1"].map(lambda x: x.lstrip("Cmp. ")).astype(int)
pc_df["Var2"] = pc_df["Var2"].map(lambda x: x.lstrip("Cmp. ")).astype(int)
Expand All @@ -42,7 +44,7 @@ def makeFigure():
cmp1 = pc_abs_df.iloc[-(i+1), 0]
cmp2 = pc_abs_df.iloc[-(i+1), 1]
plot_pair_gene_factors(X, cmp1, cmp2, ax[(3*i)])
plot_pair_cond_factors(X, cmp1, cmp2, ax[(3*i)+1])
plot_pair_cond_factors(cond_fact_meta_df, cmp1, cmp2, ax[(3*i)+1], label="Status")
plot_pair_wp(X, cmp1, cmp2, ax[(3*i)+2], frac=.001)


Expand All @@ -57,20 +59,14 @@ def plot_pair_gene_factors(X: anndata.AnnData, cmp1: int, cmp2: int, ax: Axes):
df = pd.DataFrame(
data=cmpWeights.transpose(), columns=[f"Cmp. {cmp1}", f"Cmp. {cmp2}"]
)
sns.scatterplot(data=df, x=f"Cmp. {cmp1}", y=f"Cmp. {cmp2}", ax=ax)
sns.scatterplot(data=df, x=f"Cmp. {cmp1}", y=f"Cmp. {cmp2}", ax=ax, color="k")
ax.set(title="Gene Factors")



def plot_pair_cond_factors(X: anndata.AnnData, cmp1: int, cmp2: int, ax: Axes):
def plot_pair_cond_factors(df: pd.DataFrame, cmp1: int, cmp2: int, ax: Axes, label: str):
"""Plots two condition components weights"""
cmpWeights = np.concatenate(
([X.uns["Pf2_A"][:, cmp1 - 1]], [X.uns["Pf2_A"][:, cmp2 - 1]])
)
df = pd.DataFrame(
data=cmpWeights.transpose(), columns=[f"Cmp. {cmp1}", f"Cmp. {cmp2}"]
)
sns.scatterplot(data=df, x=f"Cmp. {cmp1}", y=f"Cmp. {cmp2}", ax=ax)
sns.scatterplot(data=df, x=f"Cmp. {cmp1}", y=f"Cmp. {cmp2}", hue=label, ax=ax)
ax.set(title="Condition Factors")


Expand All @@ -84,7 +80,7 @@ def plot_pair_wp(X: anndata.AnnData, cmp1: int, cmp2: int, ax: Axes, frac: float
)
df = df.sample(frac=frac)

sns.scatterplot(data=df, x=f"Cmp. {cmp1}", y=f"Cmp. {cmp2}", ax=ax)
sns.scatterplot(data=df, x=f"Cmp. {cmp1}", y=f"Cmp. {cmp2}", ax=ax, color="k")
ax.set(title=f"WP {frac*100}% of Cells")


Expand Down
17 changes: 0 additions & 17 deletions pf2/figures/figureBS_A1.py

This file was deleted.

17 changes: 0 additions & 17 deletions pf2/figures/figureBS_S1.py

This file was deleted.

10 changes: 1 addition & 9 deletions pf2/predict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import Tuple

import pandas as pd
from sklearn.cross_decomposition import PLSRegression
from sklearn.feature_selection import RFECV
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import scale

SKF = StratifiedKFold(n_splits=10)

Expand All @@ -32,15 +28,11 @@ def run_plsr(
if not isinstance(data, pd.DataFrame):
data = pd.DataFrame(data)

data[:] = scale(data)
plsr = PLSRegression(
n_components=n_components,
scale=False,
scale=True,
max_iter=int(1E5)
)
rfe_cv = RFECV(plsr, step=1, cv=SKF, min_features_to_select=n_components)
rfe_cv.fit(data, labels)
data = data.loc[:, rfe_cv.support_]

probabilities = pd.Series(0, dtype=float, index=data.index)
for train_index, test_index in SKF.split(data, labels):
Expand Down