Skip to content

Commit

Permalink
Forces user to choose -ss or -rmsd
Browse files Browse the repository at this point in the history
and fixes -cc option and logging.
  • Loading branch information
hmcezar committed Sep 21, 2023
1 parent 0660b03 commit 97b38ff
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 21 deletions.
8 changes: 4 additions & 4 deletions clusttraj/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ def classify_structures_silhouette(
t_opt = t
labels_opt = hcl_labels

Logger.logger.info(f"Highest silhouette score: {ss_opt}")
Logger.logger.info(f"Highest silhouette score: {ss_opt}\n")

if t_opt.size > 1:
t_opt_str = ", ".join([str(t) for t in t_opt])
Logger.logger.info(
f"The following RMSD threshold values yielded the same optimial silhouette score: {t_opt_str}"
f"The following RMSD threshold values yielded the same optimial silhouette score: {t_opt_str}\n"
)
Logger.logger.info(f"The smallest RMSD of {t_opt[0]} has been adopted")
Logger.logger.info(f"The smallest RMSD of {t_opt[0]} has been adopted\n")
clusters = labels_opt[0]
else:
Logger.logger.info(f"Optimal RMSD threshold value: {t_opt}")
Logger.logger.info(f"Optimal RMSD threshold value: {t_opt}\n")
clusters = labels_opt

clust_opt.update({"optimal_cut": t_opt})
Expand Down
26 changes: 15 additions & 11 deletions clusttraj/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,6 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
type=extant_file,
help="path to the trajectory containing the conformations to be classified",
)
parser.add_argument(
"min_rmsd",
type=float,
help="value of RMSD used to classify structures as similar",
)
parser.add_argument(
"-f",
"--force-overwrite",
Expand Down Expand Up @@ -296,12 +291,6 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
action="store_true",
help="force a final Kabsch rotation before the RMSD computation (effect only when using -ns and -e)",
)
parser.add_argument(
"-ss",
"--silhouette-score",
action="store_true",
help="use the silhouette score to determine the optimal number of clusters",
)
parser.add_argument(
"--log",
type=str,
Expand All @@ -310,6 +299,21 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
help="log file (default: clusttraj.log)",
)

rmsd_criterion = parser.add_mutually_exclusive_group(required=True)

rmsd_criterion.add_argument(
"-ss",
"--silhouette-score",
action="store_true",
help="use the silhouette to determine the criterion to classify structures",
)
rmsd_criterion.add_argument(
"-rmsd",
"--min-rmsd",
type=float,
help="value of RMSD used to classify structures as similar",
)

io_group = parser.add_mutually_exclusive_group()
io_group.add_argument(
"-i",
Expand Down
1 change: 0 additions & 1 deletion clusttraj/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def main(args: list = None) -> None:
clust_opt.out_conf_fmt,
clust_opt.reorder_excl,
clust_opt.final_kabsch,
clust_opt.silhouette_score,
clust_opt.overwrite,
)

Expand Down
10 changes: 5 additions & 5 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_parse_args():


def test_configure_runtime(caplog):
clust_opt = configure_runtime(["test/ref/testtraj.xyz", "1.0", "-np", "1"])
clust_opt = configure_runtime(["test/ref/testtraj.xyz", "--min-rmsd", "1.0", "-np", "1"])

assert clust_opt.trajfile == "test/ref/testtraj.xyz"
assert clust_opt.min_rmsd == pytest.approx(1.0, abs=1e-8)
Expand All @@ -123,22 +123,22 @@ def test_configure_runtime(caplog):

with pytest.raises(SystemExit):
clust_opt = configure_runtime(
["test/ref/testtraj.xyz", "1.0", "-m", "nonexistent-method"]
["test/ref/testtraj.xyz", "--min-rmsd", "1.0", "-m", "nonexistent-method"]
)

with pytest.raises(SystemExit):
clust_opt = configure_runtime(
["test/ref/testtraj.xyz", "1.0", "--reorder-alg", "nonexistent-method"]
["test/ref/testtraj.xyz", "--min-rmsd", "1.0", "--reorder-alg", "nonexistent-method"]
)

with pytest.raises(SystemExit):
clust_opt = configure_runtime(
["test/ref/testtraj.xyz", "1.0", "-cc", "nonexistent-extension"]
["test/ref/testtraj.xyz", "--min-rmsd", "1.0", "-cc", "nonexistent-extension"]
)

with pytest.raises(SystemExit):
clust_opt = configure_runtime(
["test/ref/testtraj.xyz", "1.0", "-n", "-eex", "1"]
["test/ref/testtraj.xyz", "--min-rmsd", "1.0", "-n", "-eex", "1"]
)


Expand Down
1 change: 1 addition & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ def test_main(tmp_path):
main(
[
"test/ref/testtraj.xyz",
"--min-rmsd",
"1.0",
"-np",
"1",
Expand Down

0 comments on commit 97b38ff

Please sign in to comment.