diff --git a/protac_degradation_predictor/__init__.py b/protac_degradation_predictor/__init__.py
index f0df9d6..cb53735 100644
--- a/protac_degradation_predictor/__init__.py
+++ b/protac_degradation_predictor/__init__.py
@@ -28,6 +28,7 @@ from .protac_degradation_predictor import (
     get_protac_active_proba,
     is_protac_active,
+    get_protac_embedding,
 )
 
 __version__ = "0.0.1"
 
diff --git a/protac_degradation_predictor/protac_degradation_predictor.py b/protac_degradation_predictor/protac_degradation_predictor.py
index 9682a57..6b878cc 100644
--- a/protac_degradation_predictor/protac_degradation_predictor.py
+++ b/protac_degradation_predictor/protac_degradation_predictor.py
@@ -187,6 +187,8 @@ def get_protac_active_proba(
     # Average the predictions of all models
     preds = {}
     for ckpt_path, model in models.items():
+        # Get the last part of the path
+        ckpt_path = os.path.basename(ckpt_path)
         if not use_xgboost_models:
             pred = model(
                 poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
@@ -198,7 +200,6 @@ def get_protac_active_proba(
             preds[ckpt_path] = sigmoid(pred).detach().cpu().numpy().flatten()
         else:
             X = np.hstack([smiles_emb, poi_emb, e3_emb, cell_emb])
-            # pred = model.inplace_predict(X, (model.best_iteration, model.best_iteration))
             pred = model.inplace_predict(X)
             preds[ckpt_path] = pred
 
@@ -257,4 +258,177 @@ def is_protac_active(
     if use_majority_vote:
         return pred['majority_vote']
     else:
-        return pred['mean'] > proba_threshold
\ No newline at end of file
+        return pred['mean'] > proba_threshold
+
+
+def get_protac_embedding(
+    protac_smiles: str | List[str],
+    e3_ligase: str | List[str],
+    target_uniprot: str | List[str],
+    cell_line: str | List[str],
+    device: Literal['cpu', 'cuda'] = 'cpu',
+    use_models_from_cv: bool = False,
+    study_type: Literal['standard', 'similarity', 'target'] = 'standard',
+) -> Dict[str, np.ndarray]:
+    """ Get the embeddings of a PROTAC or a list of PROTACs.
+
+    Args:
+        protac_smiles (str | List[str]): The SMILES of the PROTAC.
+        e3_ligase (str | List[str]): The name of the E3 ligase (as listed in `config.e3_ligase2uniprot`).
+        target_uniprot (str | List[str]): The Uniprot ID of the target protein.
+        cell_line (str | List[str]): The cell line identifier.
+        device (str): The device to run the model on.
+        use_models_from_cv (bool): Whether to use the models from cross-validation.
+        study_type (str): Use models trained on the specified study. Options are 'standard', 'similarity', 'target'.
+
+    Returns:
+        Dict[str, np.ndarray]: The embeddings of the given PROTAC. Each key is the name of the model and the value is the embedding, of shape: (batch_size, model_hidden_size). NOTE: Each model has its own hidden size, so the embeddings might have different dimensions.
+    """
+    # Check that the study type is valid
+    if study_type not in ['standard', 'similarity', 'target']:
+        raise ValueError(f"Invalid study type: {study_type}. Options are 'standard', 'similarity', 'target'.")
+
+    # Check that the device is valid
+    if device not in ['cpu', 'cuda']:
+        raise ValueError(f"Invalid device: {device}. Options are 'cpu', 'cuda'.")
+
+    # Check that, if any of the model inputs is a list, all inputs are lists
+    model_inputs = [protac_smiles, e3_ligase, target_uniprot, cell_line]
+    if any(isinstance(i, list) for i in model_inputs):
+        if not all(isinstance(i, list) for i in model_inputs):
+            raise ValueError("All model inputs must be lists if one of the inputs is a list.")
+
+    # Load all required models via pkg_resources
+    device = torch.device(device)
+    models = {}
+    model_to_load = 'best_model' if not use_models_from_cv else 'cv_model'
+    for model_filename in pkg_resources.resource_listdir(__name__, 'models'):
+        if model_to_load not in model_filename:
+            continue
+        if study_type not in model_filename:
+            continue
+        if 'xgboost' not in model_filename:
+            ckpt_path = pkg_resources.resource_filename(__name__, f'models/{model_filename}')
+            models[ckpt_path] = load_model(ckpt_path).to(device)
+
+    protein2embedding = load_protein2embedding()
+    cell2embedding = load_cell2embedding()
+
+    # Get the dimension of the embeddings from the first np.array in the dictionary
+    protein_embedding_size = next(iter(protein2embedding.values())).shape[0]
+    cell_embedding_size = next(iter(cell2embedding.values())).shape[0]
+    # Setup default embeddings
+    default_protein_emb = np.zeros(protein_embedding_size)
+    default_cell_emb = np.zeros(cell_embedding_size)
+
+    # Check if any model name contains 'cellsonehot'; if so, get the one-hot encoding
+    cell2onehot = None
+    if any('cellsonehot' in model_name for model_name in models.keys()):
+        onehotenc = OneHotEncoder(sparse_output=False)
+        cell_embeddings = onehotenc.fit_transform(
+            np.array(list(cell2embedding.keys())).reshape(-1, 1)
+        )
+        cell2onehot = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
+
+    # Check if any of the model names contain 'aminoacidcnt'; if so, get the CountVectorizer
+    protein2aacnt = None
+    if any('aminoacidcnt' in model_name for model_name in models.keys()):
+        # Create a new protein2embedding dictionary with amino acid sequences
+        protac_df = load_curated_dataset()
+        # Create the dictionary mapping 'Uniprot' to 'POI Sequence'
+        protein2aacnt = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
+        # Create the dictionary mapping 'E3 Ligase Uniprot' to 'E3 Ligase Sequence'
+        e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
+        # Merge the two dictionaries into a new protein2aacnt dictionary
+        protein2aacnt.update(e32seq)
+
+        # Get count-vectorized embeddings for proteins
+        # NOTE: Check that protein2aacnt is a dictionary of strings
+        if not all(isinstance(k, str) for k in protein2aacnt.keys()):
+            raise ValueError("All keys in `protein2aacnt` must be strings.")
+        countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
+        protein_embeddings = countvec.fit_transform(
+            list(protein2aacnt.keys())
+        ).toarray()
+        protein2aacnt = {k: v for k, v in zip(protein2aacnt.keys(), protein_embeddings)}
+
+    # Convert the E3 ligase name to its Uniprot ID
+    if isinstance(e3_ligase, list):
+        e3_ligase_uniprot = [config.e3_ligase2uniprot.get(e3, '') for e3 in e3_ligase]
+    else:
+        e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
+
+    # Get the embeddings for the PROTAC, E3 ligase, target protein, and cell line.
+    # Check if the input is a list or a single string; in the latter case,
+    # convert to a list to create a batch of size 1, len(list) otherwise.
+    if isinstance(protac_smiles, list):
+        # TODO: Add warning on missing entries?
+        smiles_emb = [get_fingerprint(s) for s in protac_smiles]
+        cell_emb = [cell2embedding.get(c, default_cell_emb) for c in cell_line]
+        e3_emb = [protein2embedding.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
+        poi_emb = [protein2embedding.get(t, default_protein_emb) for t in target_uniprot]
+        # Convert to one-hot encoded cell embeddings if necessary
+        if cell2onehot is not None:
+            cell_onehot = [cell2onehot.get(c, default_cell_emb) for c in cell_line]
+        # Convert to amino acid count embeddings if necessary
+        if protein2aacnt is not None:
+            poi_aacnt = [protein2aacnt.get(t, default_protein_emb) for t in target_uniprot]
+            e3_aacnt = [protein2aacnt.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
+    else:
+        if e3_ligase not in config.e3_ligase2uniprot:
+            available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
+            logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
+        if target_uniprot not in protein2embedding:
+            logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
+        if cell_line not in cell2embedding:
+            logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
+        smiles_emb = [get_fingerprint(protac_smiles)]
+        cell_emb = [cell2embedding.get(cell_line, default_cell_emb)]
+        poi_emb = [protein2embedding.get(target_uniprot, default_protein_emb)]
+        e3_emb = [protein2embedding.get(e3_ligase_uniprot, default_protein_emb)]
+        # Convert to one-hot encoded cell embeddings if necessary
+        if cell2onehot is not None:
+            cell_onehot = [cell2onehot.get(cell_line, default_cell_emb)]
+        # Convert to amino acid count embeddings if necessary
+        if protein2aacnt is not None:
+            poi_aacnt = [protein2aacnt.get(target_uniprot, default_protein_emb)]
+            e3_aacnt = [protein2aacnt.get(e3_ligase_uniprot, default_protein_emb)]
+
+    # Convert to numpy arrays
+    smiles_emb = np.array(smiles_emb)
+    cell_emb = np.array(cell_emb)
+    poi_emb = np.array(poi_emb)
+    e3_emb = np.array(e3_emb)
+    if cell2onehot is not None:
+        cell_onehot = np.array(cell_onehot)
+    if protein2aacnt is not None:
+        poi_aacnt = np.array(poi_aacnt)
+        e3_aacnt = np.array(e3_aacnt)
+
+    # Convert to torch tensors
+    smiles_emb = torch.tensor(smiles_emb).float().to(device)
+    cell_emb = torch.tensor(cell_emb).to(device)
+    poi_emb = torch.tensor(poi_emb).to(device)
+    e3_emb = torch.tensor(e3_emb).to(device)
+    if cell2onehot is not None:
+        cell_onehot = torch.tensor(cell_onehot).float().to(device)
+    if protein2aacnt is not None:
+        poi_aacnt = torch.tensor(poi_aacnt).float().to(device)
+        e3_aacnt = torch.tensor(e3_aacnt).float().to(device)
+
+    # Collect the PROTAC embeddings from all models
+    protac_embs = {}
+    for ckpt_path, model in models.items():
+        # Get the last part of the path
+        ckpt_path = os.path.basename(ckpt_path)
+        _, protac_emb = model(
+            poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
+            e3_emb if 'aminoacidcnt' not in ckpt_path else e3_aacnt,
+            cell_emb if 'cellsonehot' not in ckpt_path else cell_onehot,
+            smiles_emb,
+            prescaled_embeddings=False,  # Normalization performed by the model
+            return_embeddings=True,
+        )
+        protac_embs[ckpt_path] = protac_emb
+
+    return protac_embs
\ No newline at end of file
diff --git a/protac_degradation_predictor/pytorch_models.py b/protac_degradation_predictor/pytorch_models.py
index b6f05ca..82d41d0 100644
--- a/protac_degradation_predictor/pytorch_models.py
+++ b/protac_degradation_predictor/pytorch_models.py
@@ -101,7 +101,7 @@ def __init__(
 
         self.dropout = nn.Dropout(p=dropout)
 
-    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
+    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings=False):
         embeddings = []
         if self.join_embeddings == 'beginning':
             # TODO: Remove this if-branch
@@ -147,8 +147,10 @@ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
         if torch.isnan(x).any():
             raise ValueError("NaN values found in sum of softmax-ed embeddings.")
         x = F.relu(self.fc1(x))
-        x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
-        x = self.fc3(x)
+        h = self.bnorm(x) if self.use_batch_norm else self.dropout(x)
+        x = self.fc3(h)
+        if return_embeddings:
+            return x, h
         return x
 
 
@@ -277,7 +279,7 @@ def scale_tensor(
             tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha
         return tensor
 
-    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
+    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True, return_embeddings=False):
         if not prescaled_embeddings:
             if self.apply_scaling:
                 if self.join_embeddings == 'beginning':
@@ -302,7 +304,7 @@ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=Tr
             raise ValueError("NaN values found in cell embeddings.")
         if torch.isnan(smiles_emb).any():
             raise ValueError("NaN values found in SMILES embeddings.")
-        return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
+        return self.model(poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings)
 
     def step(self, batch, batch_idx, stage):
         poi_emb = batch['poi_emb']
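For reference, below is a minimal usage sketch of the `get_protac_embedding` API introduced in this patch. The SMILES string, E3 ligase name, Uniprot ID, and cell line are illustrative placeholders, and the keys of the returned dictionary depend on which model checkpoints ship with the package:

```python
import protac_degradation_predictor as pdp

# All inputs below are hypothetical placeholders, not real PROTAC data.
protac_embs = pdp.get_protac_embedding(
    protac_smiles='CC(=O)Oc1ccccc1C(=O)O',  # placeholder SMILES
    e3_ligase='VHL',                        # must be a key in config.e3_ligase2uniprot
    target_uniprot='P04637',                # placeholder Uniprot ID
    cell_line='HeLa',                       # placeholder cell line identifier
    device='cpu',
)

# One embedding per loaded checkpoint, keyed by checkpoint file name.
# NOTE: hidden sizes may differ across models.
for model_name, emb in protac_embs.items():
    print(model_name, emb.shape)  # (batch_size, model_hidden_size)
```

The `return_embeddings` flag surfaces the hidden activation `h` computed just before the final `fc3` layer, so a single forward pass yields both the prediction and the embedding.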