Skip to content

Commit

Permalink
Added PROTAC embedding extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
ribesstefano committed Aug 19, 2024
1 parent 6931e74 commit 54aa7ff
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 7 deletions.
1 change: 1 addition & 0 deletions protac_degradation_predictor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .protac_degradation_predictor import (
get_protac_active_proba,
is_protac_active,
get_protac_embedding,
)

__version__ = "0.0.1"
Expand Down
178 changes: 176 additions & 2 deletions protac_degradation_predictor/protac_degradation_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def get_protac_active_proba(
# Average the predictions of all models
preds = {}
for ckpt_path, model in models.items():
# Get the last part of the path
ckpt_path = os.path.basename(ckpt_path)
if not use_xgboost_models:
pred = model(
poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
Expand All @@ -198,7 +200,6 @@ def get_protac_active_proba(
preds[ckpt_path] = sigmoid(pred).detach().cpu().numpy().flatten()
else:
X = np.hstack([smiles_emb, poi_emb, e3_emb, cell_emb])
# pred = model.inplace_predict(X, (model.best_iteration, model.best_iteration))
pred = model.inplace_predict(X)
preds[ckpt_path] = pred

Expand Down Expand Up @@ -257,4 +258,177 @@ def is_protac_active(
if use_majority_vote:
return pred['majority_vote']
else:
return pred['mean'] > proba_threshold
return pred['mean'] > proba_threshold


def get_protac_embedding(
protac_smiles: str | List[str],
e3_ligase: str | List[str],
target_uniprot: str | List[str],
cell_line: str | List[str],
device: Literal['cpu', 'cuda'] = 'cpu',
use_models_from_cv: bool = False,
study_type: Literal['standard', 'similarity', 'target'] = 'standard',
) -> Dict[str, np.ndarray]:
""" Get the embeddings of a PROTAC or a list of PROTACs.
Args:
protac_smiles (str | List[str]): The SMILES of the PROTAC.
e3_ligase (str | List[str]): The Uniprot ID of the E3 ligase.
target_uniprot (str | List[str]): The Uniprot ID of the target protein.
cell_line (str | List[str]): The cell line identifier.
device (str): The device to run the model on.
use_models_from_cv (bool): Whether to use the models from cross-validation.
study_type (str): Use models trained on the specified study. Options are 'standard', 'similarity', 'target'.
Returns:
Dict[str, np.ndarray]: The embeddings of the given PROTAC. Each key is the name of the model and the value is the embedding, of shape: (batch_size, model_hidden_size). NOTE: Each model has its own hidden size, so the embeddings might have different dimensions.
"""
# Check that the study type is valid
if study_type not in ['standard', 'similarity', 'target']:
raise ValueError(f"Invalid study type: {study_type}. Options are 'standard', 'similarity', 'target'.")

# Check that the device is valid
if device not in ['cpu', 'cuda']:
raise ValueError(f"Invalid device: {device}. Options are 'cpu', 'cuda'.")

# Check that if any the models input is a list, all inputs are lists
model_inputs = [protac_smiles, e3_ligase, target_uniprot, cell_line]
if any(isinstance(i, list) for i in model_inputs):
if not all(isinstance(i, list) for i in model_inputs):
raise ValueError("All model inputs must be lists if one of the inputs is a list.")

# Load all required models in pkg_resources
device = torch.device(device)
models = {}
model_to_load = 'best_model' if not use_models_from_cv else 'cv_model'
for model_filename in pkg_resources.resource_listdir(__name__, 'models'):
if model_to_load not in model_filename:
continue
if study_type not in model_filename:
continue
if 'xgboost' not in model_filename:
ckpt_path = pkg_resources.resource_filename(__name__, f'models/{model_filename}')
models[ckpt_path] = load_model(ckpt_path).to(device)

protein2embedding = load_protein2embedding()
cell2embedding = load_cell2embedding()

# Get the dimension of the embeddings from the first np.array in the dictionary
protein_embedding_size = next(iter(protein2embedding.values())).shape[0]
cell_embedding_size = next(iter(cell2embedding.values())).shape[0]
# Setup default embeddings
default_protein_emb = np.zeros(protein_embedding_size)
default_cell_emb = np.zeros(cell_embedding_size)

# Check if any model name contains cellsonehot, if so, get onehot encoding
cell2onehot = None
if any('cellsonehot' in model_name for model_name in models.keys()):
onehotenc = OneHotEncoder(sparse_output=False)
cell_embeddings = onehotenc.fit_transform(
np.array(list(cell2embedding.keys())).reshape(-1, 1)
)
cell2onehot = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}

# Check if any of the model names contain aminoacidcnt, if so, get the CountVectorizer
protein2aacnt = None
if any('aminoacidcnt' in model_name for model_name in models.keys()):
# Create a new protein2embedding dictionary with amino acid sequence
protac_df = load_curated_dataset()
# Create the dictionary mapping 'Uniprot' to 'POI Sequence'
protein2aacnt = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
# Create the dictionary mapping 'E3 Ligase Uniprot' to 'E3 Ligase Sequence'
e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
# Merge the two dictionaries into a new protein2aacnt dictionary
protein2aacnt.update(e32seq)

# Get count vectorized embeddings for proteins
# NOTE: Check that the protein2aacnt is a dictionary of strings
if not all(isinstance(k, str) for k in protein2aacnt.keys()):
raise ValueError("All keys in `protein2aacnt` must be strings.")
countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
protein_embeddings = countvec.fit_transform(
list(protein2aacnt.keys())
).toarray()
protein2aacnt = {k: v for k, v in zip(protein2aacnt.keys(), protein_embeddings)}

# Convert the E3 ligase to Uniprot ID
if isinstance(e3_ligase, list):
e3_ligase_uniprot = [config.e3_ligase2uniprot.get(e3, '') for e3 in e3_ligase]
else:
e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')

# Get the embeddings for the PROTAC, E3 ligase, target protein, and cell line
# Check if the input is a list or a single string, in the latter case,
# convert to a list to create a batch of size 1, len(list) otherwise.
if isinstance(protac_smiles, list):
# TODO: Add warning on missing entries?
smiles_emb = [get_fingerprint(s) for s in protac_smiles]
cell_emb = [cell2embedding.get(c, default_cell_emb) for c in cell_line]
e3_emb = [protein2embedding.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
poi_emb = [protein2embedding.get(t, default_protein_emb) for t in target_uniprot]
# Convert to one-hot encoded cell embeddings if necessary
if cell2onehot is not None:
cell_onehot = [cell2onehot.get(c, default_cell_emb) for c in cell_line]
# Convert to amino acid count embeddings if necessary
if protein2aacnt is not None:
poi_aacnt = [protein2aacnt.get(t, default_protein_emb) for t in target_uniprot]
e3_aacnt = [protein2aacnt.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
else:
if e3_ligase not in config.e3_ligase2uniprot:
available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
if target_uniprot not in protein2embedding:
logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
if cell_line not in cell2embedding:
logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
smiles_emb = [get_fingerprint(protac_smiles)]
cell_emb = [cell2embedding.get(cell_line, default_cell_emb)]
poi_emb = [protein2embedding.get(target_uniprot, default_protein_emb)]
e3_emb = [protein2embedding.get(e3_ligase_uniprot, default_protein_emb)]
# Convert to one-hot encoded cell embeddings if necessary
if cell2onehot is not None:
cell_onehot = [cell2onehot.get(cell_line, default_cell_emb)]
# Convert to amino acid count embeddings if necessary
if protein2aacnt is not None:
poi_aacnt = [protein2aacnt.get(target_uniprot, default_protein_emb)]
e3_aacnt = [protein2aacnt.get(e3_ligase_uniprot, default_protein_emb)]

# Convert to numpy arrays
smiles_emb = np.array(smiles_emb)
cell_emb = np.array(cell_emb)
poi_emb = np.array(poi_emb)
e3_emb = np.array(e3_emb)
if cell2onehot is not None:
cell_onehot = np.array(cell_onehot)
if protein2aacnt is not None:
poi_aacnt = np.array(poi_aacnt)
e3_aacnt = np.array(e3_aacnt)

# Convert to torch tensors
smiles_emb = torch.tensor(smiles_emb).float().to(device)
cell_emb = torch.tensor(cell_emb).to(device)
poi_emb = torch.tensor(poi_emb).to(device)
e3_emb = torch.tensor(e3_emb).to(device)
if cell2onehot is not None:
cell_onehot = torch.tensor(cell_onehot).float().to(device)
if protein2aacnt is not None:
poi_aacnt = torch.tensor(poi_aacnt).float().to(device)
e3_aacnt = torch.tensor(e3_aacnt).float().to(device)

# Average the predictions of all models
protac_embs = {}
for ckpt_path, model in models.items():
# Get the last part of the path
ckpt_path = os.path.basename(ckpt_path)
_, protac_emb = model(
poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
e3_emb if 'aminoacidcnt' not in ckpt_path else e3_aacnt,
cell_emb if 'cellsonehot' not in ckpt_path else cell_onehot,
smiles_emb,
prescaled_embeddings=False, # Normalization performed by the model
return_embeddings=True,
)
protac_embs[ckpt_path] = protac_emb

return protac_embs
12 changes: 7 additions & 5 deletions protac_degradation_predictor/pytorch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
self.dropout = nn.Dropout(p=dropout)


def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings=False):
embeddings = []
if self.join_embeddings == 'beginning':
# TODO: Remove this if-branch
Expand Down Expand Up @@ -147,8 +147,10 @@ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
if torch.isnan(x).any():
raise ValueError("NaN values found in sum of softmax-ed embeddings.")
x = F.relu(self.fc1(x))
x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
x = self.fc3(x)
h = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
x = self.fc3(h)
if return_embeddings:
return x, h
return x


Expand Down Expand Up @@ -277,7 +279,7 @@ def scale_tensor(
tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha
return tensor

def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True, return_embeddings=False):
if not prescaled_embeddings:
if self.apply_scaling:
if self.join_embeddings == 'beginning':
Expand All @@ -302,7 +304,7 @@ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=Tr
raise ValueError("NaN values found in cell embeddings.")
if torch.isnan(smiles_emb).any():
raise ValueError("NaN values found in SMILES embeddings.")
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings)

def step(self, batch, batch_idx, stage):
poi_emb = batch['poi_emb']
Expand Down

0 comments on commit 54aa7ff

Please sign in to comment.