Skip to content

Commit

Permalink
Update the RMSD plotting script
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Aug 12, 2024
1 parent 60489ab commit a17d91a
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions src/data/components/plot_dataset_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def plot_dataset_rmsd(
filtered_ids_to_keep_file: Optional[str] = None,
filtered_ids_to_skip: Optional[Set[str]] = None,
is_casp_dataset: bool = False,
public_plots: bool = True,
accurate_rmsd_threshold: float = 4.0,
accurate_tm_score_threshold: float = 0.7,
):
Expand All @@ -119,6 +120,7 @@ def plot_dataset_rmsd(
:param filtered_ids_to_keep_file: File containing IDs of sequences to keep.
:param filtered_ids_to_skip: Set of IDs of sequences to skip.
:param is_casp_dataset: Whether the dataset is a CASP dataset.
:param public_plots: Whether to save the public versions of the plots.
:param accurate_rmsd_threshold: RMSD threshold for accurate predictions.
:param accurate_tm_score_threshold: TM-score threshold for accurate predictions.
"""
Expand All @@ -135,7 +137,12 @@ def plot_dataset_rmsd(

dataset_rows = []

for pred_pdb_file in tqdm(os.listdir(pred_pdb_dir), desc=f"Plotting RMSD for {dataset_name}"):
dataset_suffix = " (Public)" if is_casp_dataset and public_plots else ""

for pred_pdb_file in tqdm(
os.listdir(pred_pdb_dir),
desc=f"Plotting RMSD for {dataset_name}{dataset_suffix}",
):
pdb_id = os.path.splitext(os.path.basename(pred_pdb_file))[0].split("_holo")[0]

if filtered_ids_to_keep is not None and pdb_id not in filtered_ids_to_keep:
Expand Down Expand Up @@ -193,10 +200,10 @@ def plot_dataset_rmsd(
/ dataset_df.shape[0]
)
logging.info(
f"For the {dataset_name} dataset, {accurate_predictions_percent * 100:.2f}% of the predictions have RMSD < {accurate_rmsd_threshold} and TM-score > {accurate_tm_score_threshold}."
f"For the {dataset_name}{dataset_suffix} dataset, {accurate_predictions_percent * 100:.2f}% of the predictions have RMSD < {accurate_rmsd_threshold} and TM-score > {accurate_tm_score_threshold}."
)

plot_dir = Path(output_dir) / ("public_plots" if is_casp_dataset else "plots")
plot_dir = Path(output_dir) / ("public_plots" if is_casp_dataset and public_plots else "plots")
plot_dir.mkdir(exist_ok=True)

plt.clf()
Expand Down Expand Up @@ -280,11 +287,32 @@ def main(cfg: DictConfig):
),
usalign_exec_path=cfg.usalign_exec_path,
filtered_ids_to_skip={
"T1170"
}, # NOTE: We don't score this target due to CASP internal parsing issues
"T1127v2",
"T1146",
"T1170",
"T1181",
"T1186",
}, # NOTE: We don't score `T1170` due to CASP internal parsing issues
is_casp_dataset=True,
public_plots=True,
)

# plot_dataset_rmsd(
# "CASP15 Set",
# os.path.join(cfg.data_dir, "casp15_set", "predicted_structures"),
# os.path.join(cfg.data_dir, "casp15_set", "targets"),
# os.path.join(
# cfg.data_dir,
# "casp15_set",
# ),
# usalign_exec_path=cfg.usalign_exec_path,
# filtered_ids_to_skip={
# "T1170",
# }, # NOTE: We don't score `T1170` due to CASP internal parsing issues
# is_casp_dataset=True,
# public_plots=False,
# )


if __name__ == "__main__":
main()

0 comments on commit a17d91a

Please sign in to comment.