Skip to content

Commit

Permalink
Merge pull request #343 from naik-aakash/enhance_icoxxlist_plotter
Browse files Browse the repository at this point in the history
Enhance icoxxlist plotter
  • Loading branch information
JaGeo authored Oct 28, 2024
2 parents 04756c9 + 109eede commit f265f0b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 8 deletions.
66 changes: 58 additions & 8 deletions src/lobsterpy/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,18 @@ class IcohpDistancePlotter:
Defaults to False for COHPs.
"""

COLOR_PALETTE = [
"#e41a1c",
"#377eb8",
"#4daf4a",
"#984ea3",
"#ff7f00",
"#ffff33",
"#a65628",
"#f781bf",
"#999999",
]

def __init__(self, are_coops: bool = False, are_cobis: bool = False):
"""Initialize ICOHPs or ICOBI or ICOOP vs bond lengths plotter."""
self.are_coops = are_coops
Expand All @@ -1110,6 +1122,7 @@ def add_icohps(self, label: str, icohpcollection: IcohpCollection):
icohps = []
bond_len = []
atom_pairs = []
atom_types = []
orb_data = {} # type: ignore
for indx, bond_label in enumerate(icohpcollection._list_labels):
orb_data.update({bond_label: {}})
Expand All @@ -1120,8 +1133,10 @@ def add_icohps(self, label: str, icohpcollection: IcohpCollection):
atom1 = icohpcollection._list_atom1[indx]
atom2 = icohpcollection._list_atom2[indx]
atom_pairs.append(atom1 + "-" + atom2)
atom_types.append("-".join(sorted((atom1.strip("0123456789"), atom2.strip("0123456789")))))

self._icohps[label] = {
"atom_types": atom_types,
"atom_pairs": atom_pairs,
"bond_labels": icohpcollection._list_labels,
"icohps": icohps,
Expand All @@ -1132,23 +1147,29 @@ def add_icohps(self, label: str, icohpcollection: IcohpCollection):
def get_plot(
self,
ax: mpl.axes.Axes | None = None,
alpha: float = 0.4,
marker_size: float = 50,
marker_style: str = "o",
xlim: tuple[float, float] | None = None,
ylim: tuple[float, float] | None = None,
plot_negative: bool = True,
colors: list[str] | None = None,
color_interactions: bool = False,
):
"""
Get a matplotlib plot showing the COHP or COBI or COOP with respect to bond lengths.
:param ax: Existing Matplotlib Axes object to plot to.
:param alpha: sets the transparency of markers in scatter plots
:param marker_size: sets the size of markers in scatter plots
:param marker_style: sets type of marker used in plot
:param xlim: Specifies the x-axis limits. Defaults to None for
automatic determination.
:param ylim: Specifies the y-axis limits. Defaults to None for
automatic determination.
:param plot_negative: Will plot -1*ICOHPs. Works only for ICOHPs
:param colors: list of hex color codes to be used in plot
:param color_interactions: If True, will color the interactions based on atom types
Returns:
A matplotlib object.
Expand All @@ -1173,14 +1194,43 @@ def get_plot(
ax.set_ylabel(cohp_label)
ax.set_xlabel("Bond lengths (\u00c5)")

for label, data in self._icohps.items():
x = data["bond_lengths"]
if plot_negative and cohp_label == "ICOHP (eV)":
ax.set_ylabel("$-$" + cohp_label)
y = [-1 * icohp for icohp in data["icohps"]]
else:
y = data["icohps"]
if color_interactions:
colors = InteractiveCohpPlotter.COLOR_PALETTE if colors is None else colors
atom_types = []
for data in self._icohps.values():
atom_types.extend(list(set(data["atom_types"])))
color_dict = dict(zip(atom_types, colors))
for label, data in self._icohps.items():
x = data["bond_lengths"]
if plot_negative and cohp_label == "ICOHP (eV)":
ax.set_ylabel("$-$" + cohp_label)
y = [-1 * icohp for icohp in data["icohps"]]
else:
y = data["icohps"]
for i, pair in enumerate(data["atom_types"]):
ax.scatter(
x[i],
y[i],
c=color_dict[pair],
s=marker_size,
alpha=alpha,
marker=marker_style,
label=f"{label}-({pair})",
)

# Filter out duplicate labels and add only unique labels to the legend
handles, labels = plt.gca().get_legend_handles_labels()
unique = {label: handle for handle, label in zip(handles, labels)}
ax.legend(unique.values(), unique.keys())
else:
for label, data in self._icohps.items():
x = data["bond_lengths"]
if plot_negative and cohp_label == "ICOHP (eV)":
ax.set_ylabel("$-$" + cohp_label)
y = [-1 * icohp for icohp in data["icohps"]]
else:
y = data["icohps"]

ax.scatter(x, y, s=marker_size, marker=marker_style, label=label)
ax.scatter(x, y, s=marker_size, marker=marker_style, label=label)

return plt
11 changes: 11 additions & 0 deletions tests/plotting/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,17 @@ def test_plot_data(self, icohplist_nacl, icobilist_nacl, icooplist_nacl):
assert fig_x_lims == [0, 4]
assert fig_y_lims == [0, 6]

# test for colors in icoxx plotter
ax_icoop_c = icoop_plotter.get_plot(xlim=(0, 4), ylim=(0, 6), color_interactions=True).gca()
ax_icohp_c = icohp_plotter.get_plot(xlim=(0, 4), ylim=(0, 6), color_interactions=True).gca()
ax_icobi_c = icobi_plotter.get_plot(xlim=(0, 4), ylim=(0, 6), color_interactions=True).gca()
handles_icoop, labels_icoop = ax_icoop_c.get_legend_handles_labels()
handles_icohp, labels_icohp = ax_icohp_c.get_legend_handles_labels()
handles_icobi, labels_icobi = ax_icobi_c.get_legend_handles_labels()
assert len(set(labels_icoop)) == 3
assert len(set(labels_icobi)) == 3
assert len(set(labels_icohp)) == 3


class TestPlotterExceptions:
def test_plotter_exception(self, plot_analyse_nasi):
Expand Down

0 comments on commit f265f0b

Please sign in to comment.