Skip to content

Commit

Permalink
Merge pull request #13 from hmcezar/hotfix-ci-errors
Browse files Browse the repository at this point in the history
Fixes CI errors and moves to qmllib.
  • Loading branch information
hmcezar authored Dec 19, 2024
2 parents a568a59 + 9b74947 commit 5816856
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
Expand Down
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,25 @@ The following libraries are required:
- [NumPy](http://www.numpy.org/)
- [OpenBabel](http://openbabel.org/)
- [RMSD](https://github.com/charnley/rmsd)
- [QML](https://github.com/qmlcode/qml)
- [SciPy](https://www.scipy.org/)
- [scikit-learn](http://scikit-learn.org/stable/index.html)
- [matplotlib](https://matplotlib.org/)

We use the development branch of `QML` as one of the reordering algorithms.
Since the development of `QML` has been quite slow, we provide our [own branch](https://github.com/hmcezar/qml/tree/develop) in which installation using modern versions of `numpy` is possible.
We also support [qmllib](https://github.com/qmlcode/qmllib) as an optional dependency, which provides one of the reordering algorithms.

For `openbabel`, we use the `pip` package `openbabel-wheel` which provides pre-built `openbabel` packages for Linux and MacOS.
More details can be found on the [project's GitHub page](https://github.com/njzjz/openbabel-wheel).


You can install the package using `pip`:
```bash
pip install clusttraj
```

If you want to use the `qmllib` reordering algorithm, you can install it with:
```bash
pip install clusttraj[qml]
```

## Usage
To see all the options run the script with the `-h` command option:
```bash
Expand Down
17 changes: 12 additions & 5 deletions clusttraj/distmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ def compute_distmat_line(
# whereins = np.where(
# np.isin(np.arange(natoms), reorderexcl[soluexcl]) is True
# )
whereins = np.where(np.atleast_1d(np.isin(np.arange(natoms), reorderexcl[soluexcl])))
whereins = np.where(
np.atleast_1d(np.isin(np.arange(natoms), reorderexcl[soluexcl]))
)
Psolu = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
Expand Down Expand Up @@ -255,21 +257,26 @@ def compute_distmat_line(

# reorder the solvent atoms separately
if reorder:
# if the solute is specified, reorder just the solvent atoms in this step
if nsatoms:
exclusions = np.unique(np.concatenate((np.arange(natoms), reorderexcl)))
else:
exclusions = reorderexcl

# get the view without the excluded atoms
view = np.delete(np.arange(len(P)), reorderexcl)
view = np.delete(np.arange(len(P)), exclusions)
Pview = P[view]
Paview = Pa[view]

prr = reorder(Qa[view], Paview, Q[view], Pview)
Pview = Pview[prr]

# build the total molecule with the reordered atoms
# whereins = np.where(np.isin(np.arange(len(P)), reorderexcl) is True)
whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), reorderexcl)))
whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), exclusions)))
Pr = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
P[reorderexcl],
P[exclusions],
axis=0,
)

Expand Down
49 changes: 32 additions & 17 deletions clusttraj/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dataclasses import dataclass
from .utils import get_mol_info

if importlib.util.find_spec("qml"):
if importlib.util.find_spec("qmllib"):
has_qml = True
else:
has_qml = False
Expand Down Expand Up @@ -377,7 +377,7 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:

if (args.reorder_alg == "qml") and (not has_qml):
parser.error(
"You must have the development branch of qml installed in order to use it as a reorder method."
"You must have the optional dependency `qmllib` installed in order to use it as a reorder method."
)

if args.clusters_configurations:
Expand Down Expand Up @@ -544,9 +544,11 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:

options_dict = {
"solute_natoms": args.natoms_solute,
"reorder_excl": np.asarray([x - 1 for x in args.reorder_exclusions], np.int32)
if args.reorder_exclusions
else np.asarray([], np.int32),
"reorder_excl": (
np.asarray([x - 1 for x in args.reorder_exclusions], np.int32)
if args.reorder_exclusions
else np.asarray([], np.int32)
),
"exclusions": bool(args.reorder_exclusions),
"reorder_alg_name": args.reorder_alg,
"reorder_alg": None,
Expand All @@ -556,12 +558,12 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:
"out_clust_name": args.outputclusters,
"summary_name": basenameout + ".out",
"save_confs": bool(args.clusters_configurations),
"out_conf_name": basenameout + "_confs"
if args.clusters_configurations
else None,
"out_conf_fmt": args.clusters_configurations
if args.clusters_configurations
else None,
"out_conf_name": (
basenameout + "_confs" if args.clusters_configurations else None
),
"out_conf_fmt": (
args.clusters_configurations if args.clusters_configurations else None
),
"plot": bool(args.plot),
"evo_name": basenameout + "_evo.pdf" if args.plot else None,
"dendrogram_name": basenameout + "_dendrogram.pdf" if args.plot else None,
Expand Down Expand Up @@ -777,7 +779,9 @@ def save_clusters_config(
# whereins = np.where(
# np.isin(np.arange(natoms), reorderexcl[soluexcl]) is True
# )
whereins = np.where(np.atleast_1d(np.isin(np.arange(natoms), reorderexcl)))
whereins = np.where(
np.atleast_1d(np.isin(np.arange(natoms), reorderexcl))
)
Psolu = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
Expand Down Expand Up @@ -840,21 +844,30 @@ def save_clusters_config(

# reorder the solvent atoms separately
if reorder:
# if the solute is specified, reorder just the solvent atoms in this step
if nsatoms:
exclusions = np.unique(
np.concatenate((np.arange(natoms), reorderexcl))
)
else:
exclusions = reorderexcl

# get the view without the excluded atoms
view = np.delete(np.arange(len(P)), reorderexcl)
view = np.delete(np.arange(len(P)), exclusions)
Pview = P[view]
Paview = Pa[view]

prr = reorder(Qa[view], Paview, Q[view], Pview)
Pview = Pview[prr]

# build the total molecule with the reordered atoms
# whereins = np.where(np.isin(np.arange(len(P)), reorderexcl) is True)
whereins = np.where(np.atleast_1d(np.isin(np.arange(len(P)), reorderexcl)))
whereins = np.where(
np.atleast_1d(np.isin(np.arange(len(P)), exclusions))
)
Pr = np.insert(
Pview,
[x - whereins[0].tolist().index(x) for x in whereins[0]],
P[reorderexcl],
P[exclusions],
axis=0,
)
else:
Expand Down Expand Up @@ -883,4 +896,6 @@ def save_clusters_config(

# closes the file for the cnum cluster
outfile.close()
# type: ignore


# type: ignore
3 changes: 0 additions & 3 deletions clusttraj/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@
from scipy.cluster.hierarchy import cophenet
from typing import Tuple
import numpy as np
# from .io import ClustOptions


def compute_metrics(
# clust_opt: ClustOptions,
distmat: np.ndarray,
z_matrix: np.ndarray,
clusters: np.ndarray,
) -> Tuple[np.float64, np.float64, np.float64, np.float64]:
"""Compute metrics to assess the performance of the clustering procedure.
Args:
# clust_opt (ClustOptions): The clustering options.
distmat: The distance matrix.
z_matrix (np.ndarray): The Z-matrix from hierarchical clustering procedure.
clusters (np.ndarray): The cluster classifications for each sample.
Expand Down
55 changes: 27 additions & 28 deletions clusttraj/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from .io import ClustOptions


def plot_clust_evo(
clust_opt: ClustOptions,
clusters: np.ndarray
) -> None:
def plot_clust_evo(clust_opt: ClustOptions, clusters: np.ndarray) -> None:
"""Plot the evolution of cluster classification over the given samples.
Args:
Expand All @@ -25,30 +22,34 @@ def plot_clust_evo(
Returns:
None
"""

# Define a color for the lines
line_color = (0, 0, 0, 0.5)

# plot the evolution of the cluster classification along the trajectory
plt.figure(figsize=(10, 6))

# Set the y-axis to only show integers
plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))

# Increase tick size and font size
plt.tick_params(axis='both', which='major', direction='in', labelsize=12)
plt.tick_params(axis="both", which="major", direction="in", labelsize=12)

plt.plot(range(1, len(clusters) + 1), clusters, markersize=4, color=line_color)
plt.scatter(range(1, len(clusters) + 1), clusters, marker="o", c=clusters, cmap=plt.cm.nipy_spectral)
plt.scatter(
range(1, len(clusters) + 1),
clusters,
marker="o",
c=clusters,
cmap=plt.cm.nipy_spectral,
)
plt.xlabel("Sample Index", fontsize=14)
plt.ylabel("Cluster classification", fontsize=14)
plt.savefig(clust_opt.evo_name, bbox_inches="tight")


def plot_dendrogram(
clust_opt: ClustOptions,
clusters: np.ndarray,
Z: np.ndarray
clust_opt: ClustOptions, clusters: np.ndarray, Z: np.ndarray
) -> None:
"""Plot a dendrogram based on hierarchical clustering.
Expand All @@ -65,18 +66,22 @@ def plot_dendrogram(
plt.title("Hierarchical Clustering Dendrogram", fontsize=20)
# plt.xlabel("Sample Index", fontsize=14)
plt.ylabel(r"RMSD ($\AA$)", fontsize=18)
plt.tick_params(axis='y', labelsize=18)
plt.tick_params(axis="y", labelsize=18)

# Define a color for the dashed and non-cluster lines
line_color = (0, 0, 0, 0.5)

# Add a horizontal line at the minimum RMSD value and set the threshold
if clust_opt.silhouette_score:
if isinstance(clust_opt.optimal_cut, (np.ndarray, list)):
plt.axhline(clust_opt.optimal_cut[0], linestyle="--", linewidth=2, color=line_color)
plt.axhline(
clust_opt.optimal_cut[0], linestyle="--", linewidth=2, color=line_color
)
threshold = clust_opt.optimal_cut[0]
elif isinstance(clust_opt.optimal_cut, (float, np.float32, np.float64)):
plt.axhline(clust_opt.optimal_cut, linestyle="--", linewidth=2, color=line_color)
plt.axhline(
clust_opt.optimal_cut, linestyle="--", linewidth=2, color=line_color
)
threshold = clust_opt.optimal_cut
else:
raise ValueError("optimal_cut must be a float or np.ndarray")
Expand All @@ -86,9 +91,9 @@ def plot_dendrogram(

# Use the 'nipy_spectral' cmap to color the dendrogram
unique_clusters = np.unique(clusters)
cmap = cm.get_cmap('nipy_spectral', len(unique_clusters))
cmap = cm.get_cmap("nipy_spectral", len(unique_clusters))
colors = [to_hex(cmap(i)) for i in range(cmap.N)]

hierarchy.set_link_color_palette(colors)

# Plot the dendrogram
Expand All @@ -98,18 +103,14 @@ def plot_dendrogram(
# leaf_font_size=8.0, # Font size for the x axis labels
no_labels=True,
color_threshold=threshold,
above_threshold_color=line_color
above_threshold_color=line_color,
)

# Save the dendrogram to a file
plt.savefig(clust_opt.dendrogram_name, bbox_inches="tight")


def plot_mds(
clust_opt: ClustOptions,
clusters: np.ndarray,
distmat: np.ndarray
) -> None:
def plot_mds(clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray) -> None:
"""Plot the multidimensional scaling (MDS) of the distance matrix.
Args:
Expand Down Expand Up @@ -139,7 +140,7 @@ def plot_mds(
coords = mds.fit_transform(squareform(distmat))

# Set the figure size
plt.figure(figsize=(6,6))
plt.figure(figsize=(6, 6))

# Configure tick parameters
plt.tick_params(
Expand All @@ -165,9 +166,7 @@ def plot_mds(


def plot_tsne(
clust_opt: ClustOptions,
clusters: np.ndarray,
distmat: np.ndarray
clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray
) -> None:
"""Plot the t-distributed Stochastic Neighbor Embedding 2D plot of the clustering.
Expand All @@ -194,11 +193,11 @@ def plot_tsne(

# Define a list of unique colors for each cluster
unique_clusters = np.unique(clusters)
cmap = cm.get_cmap('nipy_spectral', len(unique_clusters))
cmap = cm.get_cmap("nipy_spectral", len(unique_clusters))
colors = [cmap(i) for i in range(len(unique_clusters))]

# Set the figure size
plt.figure(figsize=(6,6))
plt.figure(figsize=(6, 6))

# Configure tick parameters
plt.tick_params(
Expand Down
10 changes: 3 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ classifiers = [
"Intended Audience :: Science/Research",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Environment :: Console",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Operating System :: OS Independent",
Expand All @@ -43,16 +43,12 @@ clusttraj = "clusttraj.main:main"
[project.optional-dependencies]
test = ["pytest", "pytest-cov[all]"]
docs = ["sphinx", "sphinx_rtd_theme"]
lint = ["ruff", "black"]
qml = ["qml"] # you can install from git+https://github.com/hmcezar/qml@develop
lint = ["black"]
qml = ["qmllib"]
all = ["clusttraj[test,docs,lint,qml]"]

[tool.setuptools]
packages = ["clusttraj"]

[tool.black]
line-length = 89

[tool.ruff]
line-length = 89
ignore = ["E501"]
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy
rmsd
rmsd>=1.6.0
scipy
scikit-learn
matplotlib
Expand Down
Loading

0 comments on commit 5816856

Please sign in to comment.