diff --git a/app/run_utils.py b/app/run_utils.py
index e61be444e..a0e34a5f4 100644
--- a/app/run_utils.py
+++ b/app/run_utils.py
@@ -44,7 +44,7 @@ def configure_logging(level=None):
     logging.basicConfig(
         level=level,
         format="[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
+        datefmt="%Y-%m-%d %H:%M:%S %Z",
         handlers=[
             logging.StreamHandler(),  # Outputs logs to stderr by default
             # If you also want to log to a file, uncomment the following line:
@@ -93,8 +93,10 @@ def run_cli_command(
     assert len(args) == len(ARG_ORDER), f'Expected {len(ARG_ORDER)} arguments, got {len(args)}'
 
+    inference_log_level = os.environ.get("INFERENCE_LOG_LEVEL", os.environ.get("LOG_LEVEL", "WARNING"))
+
     all_arg_dict = {"protein_path": protein_path, "ligand": ligand, "config": config_path,
-                    "no_final_step_noise": True}
+                    "no_final_step_noise": True, "loglevel": inference_log_level}
 
     for arg_name, arg_val in zip(ARG_ORDER, args):
         all_arg_dict[arg_name] = arg_val
 
@@ -136,12 +138,18 @@ def run_cli_command(
             capture_output=True,
         )
         logging.debug(f"Command output:\n{result.stdout}")
+        full_output = f"Standard out:\n{result.stdout}"
         if result.stderr:
             # Skip progress bar lines
             stderr_lines = result.stderr.split("\n")
             stderr_lines = filter(lambda x: "%|" not in x, stderr_lines)
             stderr_text = "\n".join(stderr_lines)
             logging.error(f"Command error:\n{stderr_text}")
+            full_output += f"\nStandard error:\n{stderr_text}"
+
+        with open(f"{temp_dir_path}/output.log", "w") as log_file:
+            log_file.write(full_output)
+
     else:
         logging.debug("Skipping command execution")
         artificial_output_dir = os.path.join(TEMP_DIR, "artificial_output")
diff --git a/datasets/process_mols.py b/datasets/process_mols.py
index 9fd7ec28a..92b9c63ac 100644
--- a/datasets/process_mols.py
+++ b/datasets/process_mols.py
@@ -17,6 +17,7 @@
 from datasets.constants import aa_short2long, atom_order, three_to_one
 from datasets.parse_chi import get_chi_angles, get_coords, aa_idx2aa_short, get_onehot_sequence
 from utils.torsion import get_transformation_mask
+from utils.logging_utils import get_logger
 
 periodic_table = GetPeriodicTable()
 
@@ -305,11 +306,11 @@ def generate_conformer(mol):
     failures, id = 0, -1
     while failures < 3 and id == -1:
         if failures > 0:
-            print(f'rdkit coords could not be generated. trying again {failures}.')
+            get_logger().debug(f'rdkit coords could not be generated. trying again {failures}.')
         id = AllChem.EmbedMolecule(mol, ps)
         failures += 1
     if id == -1:
-        print('rdkit coords could not be generated without using random coords. using random coords now.')
+        get_logger().info('rdkit coords could not be generated without using random coords. 
using random coords now.') ps.useRandomCoords = True AllChem.EmbedMolecule(mol, ps) AllChem.MMFFOptimizeMolecule(mol, confId=0) @@ -417,6 +418,7 @@ def write_mol_with_coords(mol, new_coords, path): w.write(mol) w.close() + def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False): if molecule_file.endswith('.mol2'): mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False) @@ -433,8 +435,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F elif molecule_file.endswith('.pdb'): mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False) else: - return ValueError('Expect the format of the molecule_file to be ' - 'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file)) + raise ValueError('Expect the format of the molecule_file to be ' + 'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file)) try: if sanitize or calc_charges: @@ -449,7 +451,12 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F if remove_hs: mol = Chem.RemoveHs(mol, sanitize=sanitize) - except: + + except Exception as e: + # Print stacktrace + import traceback + msg = traceback.format_exc() + get_logger().warning(f"Failed to process molecule: {molecule_file}\n{msg}") return None return mol diff --git a/inference.py b/inference.py index 938813b5b..c56941f01 100644 --- a/inference.py +++ b/inference.py @@ -1,14 +1,31 @@ +import functools +import logging +import pprint +import traceback +from argparse import ArgumentParser, Namespace, FileType import copy import os -import torch -from argparse import ArgumentParser, Namespace, FileType from functools import partial +import warnings +from typing import Mapping, Optional + +import yaml + +# Ignore pandas deprecation warning around pyarrow +warnings.filterwarnings("ignore", category=DeprecationWarning, + message="(?s).*Pyarrow will become a required dependency of pandas.*") + import numpy as np import pandas as pd -from rdkit import RDLogger +import torch from torch_geometric.loader import DataLoader + +from rdkit import RDLogger from rdkit.Chem import RemoveAllHs +# TODO imports are a little odd, utils seems to shadow things +from utils.logging_utils import configure_logger, get_logger +import utils.utils from datasets.process_mols import write_mol_with_coords from utils.download import download_and_extract from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule @@ -18,249 +35,284 @@ from utils.visualise import PDBFile from tqdm import tqdm +if os.name != 'nt': # The line does not work on Windows + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (64000, rlimit[1])) + RDLogger.DisableLog('rdApp.*') -import yaml -parser = ArgumentParser() -parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml') -parser.add_argument('--protein_ligand_csv', type=str, default=None, help='Path to a .csv file specifying the input as described in the README. 
If this is not None, it will be used instead of the --protein_path, --protein_sequence and --ligand parameters') -parser.add_argument('--complex_name', type=str, default=None, help='Name that the complex will be saved with') -parser.add_argument('--protein_path', type=str, default=None, help='Path to the protein file') -parser.add_argument('--protein_sequence', type=str, default=None, help='Sequence of the protein for ESMFold, this is ignored if --protein_path is not None') -parser.add_argument('--ligand_description', type=str, default='CCCCC(NC(=O)CCC(=O)O)P(=O)(O)OC1=CC=CC=C1', help='Either a SMILES string or the path to a molecule file that rdkit can read') - -parser.add_argument('--out_dir', type=str, default='results/user_inference', help='Directory where the outputs will be written to') -parser.add_argument('--save_visualisation', action='store_true', default=False, help='Save a pdb file with all of the steps of the reverse diffusion') -parser.add_argument('--samples_per_complex', type=int, default=10, help='Number of samples to generate') - -parser.add_argument('--model_dir', type=str, default=None, help='Path to folder with trained score model and hyperparameters') -parser.add_argument('--ckpt', type=str, default='best_ema_inference_epoch_model.pt', help='Checkpoint to use for the score model') -parser.add_argument('--confidence_model_dir', type=str, default=None, help='Path to folder with trained confidence model and hyperparameters') -parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', help='Checkpoint to use for the confidence model') - -parser.add_argument('--batch_size', type=int, default=10, help='') -parser.add_argument('--no_final_step_noise', action='store_true', default=True, help='Use no noise in the final step of the reverse diffusion') -parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps') -parser.add_argument('--actual_steps', type=int, default=None, help='Number of denoising steps that are actually performed') - -parser.add_argument('--old_score_model', action='store_true', default=False, help='') -parser.add_argument('--old_confidence_model', action='store_true', default=True, help='') -parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, help='Initial noise std proportion') -parser.add_argument('--choose_residue', action='store_true', default=False, help='') - -parser.add_argument('--temp_sampling_tr', type=float, default=1.0) -parser.add_argument('--temp_psi_tr', type=float, default=0.0) -parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5) -parser.add_argument('--temp_sampling_rot', type=float, default=1.0) -parser.add_argument('--temp_psi_rot', type=float, default=0.0) -parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5) -parser.add_argument('--temp_sampling_tor', type=float, default=1.0) -parser.add_argument('--temp_psi_tor', type=float, default=0.0) -parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5) - -parser.add_argument('--gnina_minimize', action='store_true', default=False, help='') -parser.add_argument('--gnina_path', type=str, default='gnina', help='') -parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', help='') # To redirect gnina subprocesses stdouts from the terminal window -parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='') -parser.add_argument('--gnina_autobox_add', type=float, default=4.0) -parser.add_argument('--gnina_poses_to_optimize', 
type=int, default=1) - -args = parser.parse_args() -REPOSITORY_URL = os.environ.get("REPOSITORY_URL", "https://github.com/gcorso/DiffDock") +warnings.filterwarnings("ignore", category=UserWarning, + message="The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`") + +# Prody logging is very verbose by default +prody_logger = logging.getLogger(".prody") +prody_logger.setLevel(logging.ERROR) -if args.config: - config_dict = yaml.load(args.config, Loader=yaml.FullLoader) - arg_dict = args.__dict__ - for key, value in config_dict.items(): - if isinstance(value, list): - for v in value: - arg_dict[key].append(v) - else: - arg_dict[key] = value - -# Download models if they don't exist locally -if not os.path.exists(args.model_dir): - print(f"Models not found. Downloading") - # TODO Remove the dropbox URL once the models are uploaded to GitHub release - remote_urls = [f"{REPOSITORY_URL}/releases/latest/download/diffdock_models.zip", - "https://www.dropbox.com/scl/fi/drg90rst8uhd2633tyou0/diffdock_models.zip?rlkey=afzq4kuqor2jb8adah41ro2lz&dl=1"] - downloaded_successfully = False - for remote_url in remote_urls: +REPOSITORY_URL = os.environ.get("REPOSITORY_URL", "https://github.com/gcorso/DiffDock") +REMOTE_URLS = [f"{REPOSITORY_URL}/releases/latest/download/diffdock_models.zip", + f"{REPOSITORY_URL}/releases/download/v1.1/diffdock_models.zip"] + + +def get_parser(): + parser = ArgumentParser() + parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml') + parser.add_argument('--protein_ligand_csv', type=str, default=None, help='Path to a .csv file specifying the input as described in the README. If this is not None, it will be used instead of the --protein_path, --protein_sequence and --ligand parameters') + parser.add_argument('--complex_name', type=str, default=None, help='Name that the complex will be saved with') + parser.add_argument('--protein_path', type=str, default=None, help='Path to the protein file') + parser.add_argument('--protein_sequence', type=str, default=None, help='Sequence of the protein for ESMFold, this is ignored if --protein_path is not None') + parser.add_argument('--ligand_description', type=str, default='CCCCC(NC(=O)CCC(=O)O)P(=O)(O)OC1=CC=CC=C1', help='Either a SMILES string or the path to a molecule file that rdkit can read') + + parser.add_argument('-l', '--log', '--loglevel', type=str, default='WARNING', dest="loglevel", + help='Log level. 
Default %(default)s') + + parser.add_argument('--out_dir', type=str, default='results/user_inference', help='Directory where the outputs will be written to') + parser.add_argument('--save_visualisation', action='store_true', default=False, help='Save a pdb file with all of the steps of the reverse diffusion') + parser.add_argument('--samples_per_complex', type=int, default=10, help='Number of samples to generate') + + parser.add_argument('--model_dir', type=str, default=None, help='Path to folder with trained score model and hyperparameters') + parser.add_argument('--ckpt', type=str, default='best_ema_inference_epoch_model.pt', help='Checkpoint to use for the score model') + parser.add_argument('--confidence_model_dir', type=str, default=None, help='Path to folder with trained confidence model and hyperparameters') + parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', help='Checkpoint to use for the confidence model') + + parser.add_argument('--batch_size', type=int, default=10, help='') + parser.add_argument('--no_final_step_noise', action='store_true', default=True, help='Use no noise in the final step of the reverse diffusion') + parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps') + parser.add_argument('--actual_steps', type=int, default=None, help='Number of denoising steps that are actually performed') + + parser.add_argument('--old_score_model', action='store_true', default=False, help='') + parser.add_argument('--old_confidence_model', action='store_true', default=True, help='') + parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, help='Initial noise std proportion') + parser.add_argument('--choose_residue', action='store_true', default=False, help='') + + parser.add_argument('--temp_sampling_tr', type=float, default=1.0) + parser.add_argument('--temp_psi_tr', type=float, default=0.0) + parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5) + parser.add_argument('--temp_sampling_rot', type=float, default=1.0) + parser.add_argument('--temp_psi_rot', type=float, default=0.0) + parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5) + parser.add_argument('--temp_sampling_tor', type=float, default=1.0) + parser.add_argument('--temp_psi_tor', type=float, default=0.0) + parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5) + + parser.add_argument('--gnina_minimize', action='store_true', default=False, help='') + parser.add_argument('--gnina_path', type=str, default='gnina', help='') + parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', help='') # To redirect gnina subprocesses stdouts from the terminal window + parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='') + parser.add_argument('--gnina_autobox_add', type=float, default=4.0) + parser.add_argument('--gnina_poses_to_optimize', type=int, default=1) + + return parser + + +def main(args): + + configure_logger(args.loglevel) + logger = get_logger() + + if args.config: + config_dict = yaml.load(args.config, Loader=yaml.FullLoader) + arg_dict = args.__dict__ + for key, value in config_dict.items(): + if isinstance(value, list): + for v in value: + arg_dict[key].append(v) + else: + arg_dict[key] = value + + # Download models if they don't exist locally + if not os.path.exists(args.model_dir): + logger.info(f"Models not found. 
Downloading") + remote_urls = REMOTE_URLS + downloaded_successfully = False + for remote_url in remote_urls: + try: + logger.info(f"Attempting download from {remote_url}") + files_downloaded = download_and_extract(remote_url, os.path.dirname(args.model_dir)) + if not files_downloaded: + logger.info(f"Download from {remote_url} failed.") + continue + logger.info(f"Downloaded and extracted {len(files_downloaded)} files from {remote_url}") + downloaded_successfully = True + # Once we have downloaded the models, we can break the loop + break + except Exception as e: + pass + + if not downloaded_successfully: + raise Exception(f"Models not found locally and failed to download them from {remote_urls}") + + os.makedirs(args.out_dir, exist_ok=True) + with open(f'{args.model_dir}/model_parameters.yml') as f: + score_model_args = Namespace(**yaml.full_load(f)) + if args.confidence_model_dir is not None: + with open(f'{args.confidence_model_dir}/model_parameters.yml') as f: + confidence_args = Namespace(**yaml.full_load(f)) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + logger.info(f"DiffDock will run on {device}") + + if args.protein_ligand_csv is not None: + df = pd.read_csv(args.protein_ligand_csv) + complex_name_list = set_nones(df['complex_name'].tolist()) + protein_path_list = set_nones(df['protein_path'].tolist()) + protein_sequence_list = set_nones(df['protein_sequence'].tolist()) + ligand_description_list = set_nones(df['ligand_description'].tolist()) + else: + complex_name_list = [args.complex_name if args.complex_name else f"complex_0"] + protein_path_list = [args.protein_path] + protein_sequence_list = [args.protein_sequence] + ligand_description_list = [args.ligand_description] + + complex_name_list = [name if name is not None else f"complex_{i}" for i, name in enumerate(complex_name_list)] + for name in complex_name_list: + write_dir = f'{args.out_dir}/{name}' + os.makedirs(write_dir, exist_ok=True) + + # preprocessing of complexes into geometric graphs + test_dataset = InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list, + ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list, + lm_embeddings=True, + receptor_radius=score_model_args.receptor_radius, remove_hs=score_model_args.remove_hs, + c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors, + all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius, + atom_max_neighbors=score_model_args.atom_max_neighbors, + knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph) + test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) + + if args.confidence_model_dir is not None and not confidence_args.use_original_model_cache: + logger.info('Confidence model uses different type of graphs than the score model. 
' + 'Loading (or creating if not existing) the data for the confidence model now.') + confidence_test_dataset = \ + InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list, + ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list, + lm_embeddings=True, + receptor_radius=confidence_args.receptor_radius, remove_hs=confidence_args.remove_hs, + c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors, + all_atoms=confidence_args.all_atoms, atom_radius=confidence_args.atom_radius, + atom_max_neighbors=confidence_args.atom_max_neighbors, + precomputed_lm_embeddings=test_dataset.lm_embeddings, + knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph) + else: + confidence_test_dataset = None + + t_to_sigma = partial(t_to_sigma_compl, args=score_model_args) + + model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model) + state_dict = torch.load(f'{args.model_dir}/{args.ckpt}', map_location=torch.device('cpu')) + model.load_state_dict(state_dict, strict=True) + model = model.to(device) + model.eval() + + if args.confidence_model_dir is not None: + confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True, + confidence_mode=True, old=args.old_confidence_model) + state_dict = torch.load(f'{args.confidence_model_dir}/{args.confidence_ckpt}', map_location=torch.device('cpu')) + confidence_model.load_state_dict(state_dict, strict=True) + confidence_model = confidence_model.to(device) + confidence_model.eval() + else: + confidence_model = None + confidence_args = None + + tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta') + + failures, skipped = 0, 0 + N = args.samples_per_complex + test_ds_size = len(test_dataset) + logger.info(f'Size of test dataset: {test_ds_size}') + for idx, orig_complex_graph in tqdm(enumerate(test_loader)): + if not orig_complex_graph.success[0]: + skipped += 1 + logger.warning(f"The test dataset did not contain {test_dataset.complex_names[idx]} for {test_dataset.ligand_descriptions[idx]} and {test_dataset.protein_files[idx]}. 
We are skipping this complex.") + continue try: - print(f"Attempting download from {remote_url}") - files_downloaded = download_and_extract(remote_url, os.path.dirname(args.model_dir)) - if not files_downloaded: - print(f"Download from {remote_url} failed.") - continue - print(f"Downloaded and extracted {len(files_downloaded)} files from {remote_url}") - downloaded_successfully = True - # Once we have downloaded the models, we can break the loop - break - except Exception as e: - pass - - if not downloaded_successfully: - raise Exception(f"Models not found locally and failed to download them from {remote_urls}") - -os.makedirs(args.out_dir, exist_ok=True) -with open(f'{args.model_dir}/model_parameters.yml') as f: - score_model_args = Namespace(**yaml.full_load(f)) -if args.confidence_model_dir is not None: - with open(f'{args.confidence_model_dir}/model_parameters.yml') as f: - confidence_args = Namespace(**yaml.full_load(f)) - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f"DiffDock will run on {device}") - -if args.protein_ligand_csv is not None: - df = pd.read_csv(args.protein_ligand_csv) - complex_name_list = set_nones(df['complex_name'].tolist()) - protein_path_list = set_nones(df['protein_path'].tolist()) - protein_sequence_list = set_nones(df['protein_sequence'].tolist()) - ligand_description_list = set_nones(df['ligand_description'].tolist()) -else: - complex_name_list = [args.complex_name if args.complex_name else f"complex_0"] - protein_path_list = [args.protein_path] - protein_sequence_list = [args.protein_sequence] - ligand_description_list = [args.ligand_description] - -complex_name_list = [name if name is not None else f"complex_{i}" for i, name in enumerate(complex_name_list)] -for name in complex_name_list: - write_dir = f'{args.out_dir}/{name}' - os.makedirs(write_dir, exist_ok=True) - -# preprocessing of complexes into geometric graphs -test_dataset = InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list, - ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list, - lm_embeddings=True, - receptor_radius=score_model_args.receptor_radius, remove_hs=score_model_args.remove_hs, - c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors, - all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius, - atom_max_neighbors=score_model_args.atom_max_neighbors, - knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph) -test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) - -if args.confidence_model_dir is not None and not confidence_args.use_original_model_cache: - print('HAPPENING | confidence model uses different type of graphs than the score model. 
' - 'Loading (or creating if not existing) the data for the confidence model now.') - confidence_test_dataset = \ - InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list, - ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list, - lm_embeddings=True, - receptor_radius=confidence_args.receptor_radius, remove_hs=confidence_args.remove_hs, - c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors, - all_atoms=confidence_args.all_atoms, atom_radius=confidence_args.atom_radius, - atom_max_neighbors=confidence_args.atom_max_neighbors, - precomputed_lm_embeddings=test_dataset.lm_embeddings, - knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph) -else: - confidence_test_dataset = None - -t_to_sigma = partial(t_to_sigma_compl, args=score_model_args) - -model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model) -state_dict = torch.load(f'{args.model_dir}/{args.ckpt}', map_location=torch.device('cpu')) -model.load_state_dict(state_dict, strict=True) -model = model.to(device) -model.eval() - -if args.confidence_model_dir is not None: - confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True, - confidence_mode=True, old=args.old_confidence_model) - state_dict = torch.load(f'{args.confidence_model_dir}/{args.confidence_ckpt}', map_location=torch.device('cpu')) - confidence_model.load_state_dict(state_dict, strict=True) - confidence_model = confidence_model.to(device) - confidence_model.eval() -else: - confidence_model = None - confidence_args = None - -tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta') - -failures, skipped = 0, 0 -N = args.samples_per_complex -print('Size of test dataset: ', len(test_dataset)) -for idx, orig_complex_graph in tqdm(enumerate(test_loader)): - if not orig_complex_graph.success[0]: - skipped += 1 - print(f"HAPPENING | The test dataset did not contain {test_dataset.complex_names[idx]} for {test_dataset.ligand_descriptions[idx]} and {test_dataset.protein_files[idx]}. We are skipping this complex.") - continue - try: - if confidence_test_dataset is not None: - confidence_complex_graph = confidence_test_dataset[idx] - if not confidence_complex_graph.success: - skipped += 1 - print(f"HAPPENING | The confidence dataset did not contain {orig_complex_graph.name}. 
We are skipping this complex.") - continue - confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(N)] - else: - confidence_data_list = None - data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)] - randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max, - initial_noise_std_proportion=args.initial_noise_std_proportion, - choose_residue=args.choose_residue) - - lig = orig_complex_graph.mol[0] - - # initialize visualisation - pdb = None - if args.save_visualisation: - visualization_list = [] - for graph in data_list: - pdb = PDBFile(lig) - pdb.add(lig, 0, 0) - pdb.add((orig_complex_graph['ligand'].pos + orig_complex_graph.original_center).detach().cpu(), 1, 0) - pdb.add((graph['ligand'].pos + graph.original_center).detach().cpu(), part=1, order=1) - visualization_list.append(pdb) - else: - visualization_list = None - - # run reverse diffusion - data_list, confidence = sampling(data_list=data_list, model=model, - inference_steps=args.actual_steps if args.actual_steps is not None else args.inference_steps, - tr_schedule=tr_schedule, rot_schedule=tr_schedule, tor_schedule=tr_schedule, - device=device, t_to_sigma=t_to_sigma, model_args=score_model_args, - visualization_list=visualization_list, confidence_model=confidence_model, - confidence_data_list=confidence_data_list, confidence_model_args=confidence_args, - batch_size=args.batch_size, no_final_step_noise=args.no_final_step_noise, - temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot, - args.temp_sampling_tor], - temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor], - temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot, - args.temp_sigma_data_tor]) - - ligand_pos = np.asarray([complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for complex_graph in data_list]) - - # reorder predictions based on confidence output - if confidence is not None and isinstance(confidence_args.rmsd_classification_cutoff, list): - confidence = confidence[:, 0] - if confidence is not None: - confidence = confidence.cpu().numpy() - re_order = np.argsort(confidence)[::-1] - confidence = confidence[re_order] - ligand_pos = ligand_pos[re_order] - - # save predictions - write_dir = f'{args.out_dir}/{complex_name_list[idx]}' - for rank, pos in enumerate(ligand_pos): - mol_pred = copy.deepcopy(lig) - if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred) - if rank == 0: write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}.sdf')) - write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}_confidence{confidence[rank]:.2f}.sdf')) - - # save visualisation frames - if args.save_visualisation: - if confidence is not None: - for rank, batch_idx in enumerate(re_order): - visualization_list[batch_idx].write(os.path.join(write_dir, f'rank{rank+1}_reverseprocess.pdb')) + if confidence_test_dataset is not None: + confidence_complex_graph = confidence_test_dataset[idx] + if not confidence_complex_graph.success: + skipped += 1 + logger.warning(f"The confidence dataset did not contain {orig_complex_graph.name}. 
We are skipping this complex.") + continue + confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(N)] else: - for rank, batch_idx in enumerate(ligand_pos): - visualization_list[batch_idx].write(os.path.join(write_dir, f'rank{rank+1}_reverseprocess.pdb')) - - except Exception as e: - print("Failed on", orig_complex_graph["name"], e) - failures += 1 + confidence_data_list = None + data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)] + randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max, + initial_noise_std_proportion=args.initial_noise_std_proportion, + choose_residue=args.choose_residue) + + lig = orig_complex_graph.mol[0] + + # initialize visualisation + pdb = None + if args.save_visualisation: + visualization_list = [] + for graph in data_list: + pdb = PDBFile(lig) + pdb.add(lig, 0, 0) + pdb.add((orig_complex_graph['ligand'].pos + orig_complex_graph.original_center).detach().cpu(), 1, 0) + pdb.add((graph['ligand'].pos + graph.original_center).detach().cpu(), part=1, order=1) + visualization_list.append(pdb) + else: + visualization_list = None + + # run reverse diffusion + data_list, confidence = sampling(data_list=data_list, model=model, + inference_steps=args.actual_steps if args.actual_steps is not None else args.inference_steps, + tr_schedule=tr_schedule, rot_schedule=tr_schedule, tor_schedule=tr_schedule, + device=device, t_to_sigma=t_to_sigma, model_args=score_model_args, + visualization_list=visualization_list, confidence_model=confidence_model, + confidence_data_list=confidence_data_list, confidence_model_args=confidence_args, + batch_size=args.batch_size, no_final_step_noise=args.no_final_step_noise, + temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot, + args.temp_sampling_tor], + temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor], + temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot, + args.temp_sigma_data_tor]) + + ligand_pos = np.asarray([complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for complex_graph in data_list]) + + # reorder predictions based on confidence output + if confidence is not None and isinstance(confidence_args.rmsd_classification_cutoff, list): + confidence = confidence[:, 0] + if confidence is not None: + confidence = confidence.cpu().numpy() + re_order = np.argsort(confidence)[::-1] + confidence = confidence[re_order] + ligand_pos = ligand_pos[re_order] + + # save predictions + write_dir = f'{args.out_dir}/{complex_name_list[idx]}' + for rank, pos in enumerate(ligand_pos): + mol_pred = copy.deepcopy(lig) + if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred) + if rank == 0: write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}.sdf')) + write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}_confidence{confidence[rank]:.2f}.sdf')) + + # save visualisation frames + if args.save_visualisation: + if confidence is not None: + for rank, batch_idx in enumerate(re_order): + visualization_list[batch_idx].write(os.path.join(write_dir, f'rank{rank+1}_reverseprocess.pdb')) + else: + for rank, batch_idx in enumerate(ligand_pos): + visualization_list[batch_idx].write(os.path.join(write_dir, f'rank{rank+1}_reverseprocess.pdb')) -print(f'Failed for {failures} complexes') -print(f'Skipped {skipped} complexes') -print(f'Results are in {args.out_dir}') \ No newline at end of file + except Exception as e: + logger.warning("Failed on", 
orig_complex_graph["name"], e) + failures += 1 + + result_msg = f""" + Failed for {failures} / {test_ds_size} complexes. + Skipped {skipped} / {test_ds_size} complexes. +""" + if failures or skipped: + logger.warning(result_msg) + else: + logger.info(result_msg) + logger.info(f"Results saved in {args.out_dir}") + + +if __name__ == "__main__": + _args = get_parser().parse_args() + main(_args) diff --git a/utils/logging_utils.py b/utils/logging_utils.py new file mode 100644 index 000000000..65c21cbf8 --- /dev/null +++ b/utils/logging_utils.py @@ -0,0 +1,99 @@ +import logging +import multiprocessing +import os +import subprocess + + +LOGGER_NAME = "DiffDock" +LOGLEVEL_KEY = "DIFFDOCK_LOGLEVEL" + + +def _get_formatter(loglevel="INFO"): + warn_fmt = "[%(asctime)s] %(levelname)s -%(message)s" + debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s" + fmt = debug_fmt if loglevel.upper() in {"DEBUG", "INFO"} else warn_fmt + return logging.Formatter( + fmt=fmt, + datefmt="%Y-%b-%d %H:%M:%S %Z", + ) + + +def remove_all_handlers(logger): + while logger.hasHandlers(): + logger.removeHandler(logger.handlers[0]) + + +def configure_logger(loglevel=None, logger_name=LOGGER_NAME, logfile=None): + """Do basic logger configuration and set our main logger""" + + # Set as environment variable so other processes can retrieve it + if loglevel is None: + loglevel = os.environ.get(LOGLEVEL_KEY, "WARNING") + else: + os.environ[LOGLEVEL_KEY] = loglevel + + logger = logging.getLogger(logger_name) + logger.setLevel(loglevel) + remove_all_handlers(logger) + logger.propagate = False + + formatter = _get_formatter(loglevel) + def _prep_handler(handler): + for ex_handler in logger.handlers: + if type(ex_handler) == type(handler): + # Remove old handler, don't want to double-handle + logger.removeHandler(ex_handler) + handler.setLevel(loglevel) + handler.setFormatter(formatter) + logger.addHandler(handler) + + sh = logging.StreamHandler() + _prep_handler(sh) + + if logfile is not None: + fh = logging.FileHandler(logfile, mode="a") + _prep_handler(fh) + + +def get_logger(base_name=LOGGER_NAME): + """ + Return a logger. + Use a different logger in each subprocess, though they should all have the same log level. + """ + pid = os.getpid() + logger_name = f"{base_name}-process-{pid}" + logger = logging.getLogger(logger_name) + if not logger.hasHandlers(): + configure_logger(logger_name=logger_name) + return logger + + +def get_git_revision_hash() -> str: + """ + Get the full git revision of the latest HEAD. + Note: This only works if run from git directory. + """ + return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + + +def get_git_revision_short_hash() -> str: + """ + Get the short git revision of the latest HEAD. + Note: This only works if run from git directory. + Returns: + + """ + return ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + .decode("ascii") + .strip() + ) + + +def check_git_uncommitted() -> bool: + changed_files = ( + subprocess.check_output(["git", "status", "-suno"]).decode("ascii").strip() + ) + lines = list(filter(lambda x: x, changed_files.split("\n"))) + return len(lines) > 0 +