From a17d91ae24def6efe7050c880995f8b39d2d9be0 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Mon, 12 Aug 2024 12:58:07 -0500 Subject: [PATCH] Update the RMSD plotting script --- src/data/components/plot_dataset_rmsd.py | 38 ++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/data/components/plot_dataset_rmsd.py b/src/data/components/plot_dataset_rmsd.py index 0139a1d..8294e62 100644 --- a/src/data/components/plot_dataset_rmsd.py +++ b/src/data/components/plot_dataset_rmsd.py @@ -105,6 +105,7 @@ def plot_dataset_rmsd( filtered_ids_to_keep_file: Optional[str] = None, filtered_ids_to_skip: Optional[Set[str]] = None, is_casp_dataset: bool = False, + public_plots: bool = True, accurate_rmsd_threshold: float = 4.0, accurate_tm_score_threshold: float = 0.7, ): @@ -119,6 +120,7 @@ def plot_dataset_rmsd( :param filtered_ids_to_keep_file: File containing IDs of sequences to keep. :param filtered_ids_to_skip: Set of IDs of sequences to skip. :param is_casp_dataset: Whether the dataset is a CASP dataset. + :param public_plots: Whether to save the public versions of the plots. :param accurate_rmsd_threshold: RMSD threshold for accurate predictions. :param accurate_tm_score_threshold: TM-score threshold for accurate predictions. """ @@ -135,7 +137,12 @@ def plot_dataset_rmsd( dataset_rows = [] - for pred_pdb_file in tqdm(os.listdir(pred_pdb_dir), desc=f"Plotting RMSD for {dataset_name}"): + dataset_suffix = " (Public)" if is_casp_dataset and public_plots else "" + + for pred_pdb_file in tqdm( + os.listdir(pred_pdb_dir), + desc=f"Plotting RMSD for {dataset_name}{dataset_suffix}", + ): pdb_id = os.path.splitext(os.path.basename(pred_pdb_file))[0].split("_holo")[0] if filtered_ids_to_keep is not None and pdb_id not in filtered_ids_to_keep: @@ -193,10 +200,10 @@ def plot_dataset_rmsd( / dataset_df.shape[0] ) logging.info( - f"For the {dataset_name} dataset, {accurate_predictions_percent * 100:.2f}% of the predictions have RMSD < {accurate_rmsd_threshold} and TM-score > {accurate_tm_score_threshold}." + f"For the {dataset_name}{dataset_suffix} dataset, {accurate_predictions_percent * 100:.2f}% of the predictions have RMSD < {accurate_rmsd_threshold} and TM-score > {accurate_tm_score_threshold}." ) - plot_dir = Path(output_dir) / ("public_plots" if is_casp_dataset else "plots") + plot_dir = Path(output_dir) / ("public_plots" if is_casp_dataset and public_plots else "plots") plot_dir.mkdir(exist_ok=True) plt.clf() @@ -280,11 +287,32 @@ def main(cfg: DictConfig): ), usalign_exec_path=cfg.usalign_exec_path, filtered_ids_to_skip={ - "T1170" - }, # NOTE: We don't score this target due to CASP internal parsing issues + "T1127v2", + "T1146", + "T1170", + "T1181", + "T1186", + }, # NOTE: We don't score `T1170` due to CASP internal parsing issues is_casp_dataset=True, + public_plots=True, ) + # plot_dataset_rmsd( + # "CASP15 Set", + # os.path.join(cfg.data_dir, "casp15_set", "predicted_structures"), + # os.path.join(cfg.data_dir, "casp15_set", "targets"), + # os.path.join( + # cfg.data_dir, + # "casp15_set", + # ), + # usalign_exec_path=cfg.usalign_exec_path, + # filtered_ids_to_skip={ + # "T1170", + # }, # NOTE: We don't score `T1170` due to CASP internal parsing issues + # is_casp_dataset=True, + # public_plots=False, + # ) + if __name__ == "__main__": main()