From 7df0777550d749a959b62d1faa491bef7c5b4866 Mon Sep 17 00:00:00 2001 From: Matteo Giantomassi Date: Sun, 15 Sep 2024 11:58:22 +0200 Subject: [PATCH] Add Robot.getattrs_alleq method --- abipy/abilab.py | 2 + abipy/abio/robots.py | 53 +++++++++--- abipy/eph/gpath.py | 202 +++++++++++++++++++++++++------------------ 3 files changed, 158 insertions(+), 99 deletions(-) diff --git a/abipy/abilab.py b/abipy/abilab.py index a44e8e9a5..b7387e2f6 100644 --- a/abipy/abilab.py +++ b/abipy/abilab.py @@ -75,6 +75,7 @@ from abipy.eph.rta import RtaFile, RtaRobot from abipy.eph.transportfile import TransportFile from abipy.eph.gstore import GstoreFile +from abipy.eph.gpath import GpathFile from abipy.wannier90 import WoutFile, AbiwanFile, AbiwanRobot from abipy.electrons.lobster import CoxpFile, ICoxpFile, LobsterDoscarFile, LobsterInput, LobsterAnalyzer @@ -173,6 +174,7 @@ def _straceback(): ("A2F.nc", A2fFile), ("SIGEPH.nc", SigEPhFile), ("GSTORE.nc", GstoreFile), + ("GPATH.nc", GpathFile), ("TRANSPORT.nc",TransportFile), ("RTA.nc",RtaFile), ("V1SYM.nc", V1symFile), diff --git a/abipy/abio/robots.py b/abipy/abio/robots.py index 3c7200bd5..4be9b0b76 100644 --- a/abipy/abio/robots.py +++ b/abipy/abio/robots.py @@ -613,23 +613,50 @@ def _repr_html_(self) -> str: """Integration with jupyter_ notebooks.""" return '
    \n{}\n
'.format("\n".join("
  • %s
  • " % label for label, abifile in self.items())) - def getattr_alleq(self, aname : str): + def getattrs_alleq(self, *aname_args) -> list: """ - Return the value of attribute aname. + Return list of attribute values for each attribute name in *aname_args. + """ + return [self.getattr_alleq(aname) for aname in aname_args] + + def getattr_alleq(self, aname: str): + """ + Return the value of attribute aname. Try firs in self then in self.r Raises ValueError if value is not the same across all the files in the robot. """ - val1 = getattr(self.abifiles[0], aname) - for abifile in self.abifiles[1:]: - val2 = getattr(abifile, aname) - if isinstance(val1, (str, int, float)): - eq = val1 == val2 - elif isinstance(val1, np.ndarray): - eq = np.allclose(val1, val2) - if not eq: - raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}") - - return val1 + def get_obj_list(what: str): + if what == "abifiles": + return self.abifiles + elif what == "r": + return [abifile.r for abifile in self.abifiles] + + raise ValueError(f"Invalid {what=}") + + err_msg = [] + + for what in ["abifiles", "r"]: + objs = get_obj_list(what) + + try: + val1 = getattr(objs[0], aname) + except AttributeError as exc: + err_msg.append(str(exc)) + continue + + for obj in objs[1:]: + val2 = getattr(obj, aname) + if isinstance(val1, (str, int, float)): + eq = val1 == val2 + elif isinstance(val1, np.ndarray): + eq = np.allclose(val1, val2) + if not eq: + raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}") + + return val1 + + if err_msg: + raise ValueError("\n".join(err_msg)) @property def abifiles(self) -> list: diff --git a/abipy/eph/gpath.py b/abipy/eph/gpath.py index 6c7807b72..6d771a542 100644 --- a/abipy/eph/gpath.py +++ b/abipy/eph/gpath.py @@ -18,7 +18,7 @@ from abipy.tools.typing import PathLike #from abipy.tools.numtools import nparr_to_df from abipy.tools.plotting import (add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, set_visible, - rotate_ticklabels, ax_append_title, set_ax_xylabels, linestyles, Marker, set_grid_legend) + rotate_ticklabels, ax_append_title, set_ax_xylabels, linestyles, Marker, set_grid_legend, set_axlims) from abipy.electrons.ebands import ElectronBands, RobotWithEbands from abipy.dfpt.phonons import PhononBands from abipy.dfpt.phtk import NonAnalyticalPh @@ -27,6 +27,13 @@ from abipy.eph.common import BaseEphReader +def k2s(k_vector, fmt=".3f", threshold = 1e-8) -> str: + k_vector = np.asarray(k_vector) + k_vector[np.abs(k_vector) < threshold] = 0 + + return "[" + ", ".join(f"{x:.3f}" for x in k_vector) + "]" + + class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter): """ This file stores the e-ph matrix elements along a k/q path @@ -112,9 +119,23 @@ def to_string(self, verbose: int=0) -> str: return "\n".join(lines) + @staticmethod + def _get_which_g_list(which_g: str) -> list[str]: + all_choices = ["avg", "raw"] + if which_g == "all": + return all_choices + + if which_g not in all_choices: + raise ValueError(f"Invalid {which=}, should be in {all_choices=}") + + return [which_g] + + def _get_band_range(self, band_range): + return (self.r.bstart, self.r.bstop) if band_range is None else band_range + @add_fig_kwargs - def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1, - with_phbands=True, with_ebands=False, + def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1, gmax_mev=250, + ph_modes=None, with_phbands=True, with_ebands=False, ax_mat=None, fontsize=8, **kwargs) -> Figure: """ Plot the averaged |g(k,q)| in meV units along the q-path @@ -124,21 +145,22 @@ def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1 which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both. with_qexp: Multiply |g(q)| by |q|^{with_qexp}. scale: Scaling factor for the marker size used when with_phbands is True. + gmax_mev: Show results up to gmax in meV. + ph_modes: List of ph branch indices to show (start from 0). If None all modes are shown. with_phbands: False if phonon bands should now be displayed. with_ebands: False if electron bands should now be displayed. ax_mat: List of |matplotlib-Axes| or None if a new figure should be created. fontsize: fontsize for legends and titles """ - nrows, ncols = 1 + int((np.array([with_ebands, with_phbands]) == True).sum()), self.r.nsppol - which_g_list = [which_g] - if which_g == "all": - which_g_list = ["avg", "raw"] - nrows += 1 + which_g_list = self._get_which_g_list(which_g) + nrows, ncols = len(which_g_list) + int((np.array([with_ebands, with_phbands]) == True).sum()), self.r.nsppol ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) marker_color = "gold" - band_range = (self.r.bstart, self.r.bstop) if band_range is None else band_range + band_range = self._get_band_range(band_range) + + #facts_q, g_label, g_units = self.get_info(which_g, with_qexp) facts_q = np.ones(len(self.phbands.qpoints)) if with_qexp == 0 else \ np.array([qpt.norm for qpt in self.phbands.qpoints]) ** with_qexp @@ -150,18 +172,22 @@ def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1 ax_cnt = -1 for which_g in which_g_list: - # Select data according to which_g and multiply by facts_q + # Select ys according to which_g and multiply by facts_q g_nuq = dict(avg=g_nuq_avg, raw=g_nuq_raw)[which_g] * facts_q[None,:] + # Plot g_nu(q) ax_cnt += 1 ax = ax_mat[ax_cnt, spin] for nu in range(self.r.natom3): + if ph_modes is not None and nu not in ph_modes: continue ax.plot(g_nuq[nu], label=f"{nu=}") - self.phbands.decorate_ax(ax, units="meV") g_label = r"$|g^{\text{%s}}_{\mathbf{q}}|$ %s" % (which_g, q_label) set_grid_legend(ax, fontsize, ylabel="%s %s" % (g_label, g_units)) + if gmax_mev is not None and with_qexp == 0: + set_axlims(ax, [0, gmax_mev], "y") + if with_phbands: # Plot phonons bands + averaged g(q) as markers ax_cnt += 1 @@ -188,15 +214,15 @@ def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1 # Add title. if (kpt_name := self.structure.findname_in_hsym_stars(self.r.eph_fix_wavec)) is None: - kpt_name = str(self.r.eph_fix_wavec) + qpt_name = k2s(self.r.eph_fix_wavec) fig.suptitle(f"k = {kpt_name}" + f" m, n = {band_range[0]} - {band_range[1] - 1}") return fig @add_fig_kwargs - def plot_g_kpath(self, band_range=None, which_g="sym", scale=1, with_ebands=True, - ax_mat=None, fontsize=8, **kwargs) -> Figure: + def plot_g_kpath(self, band_range=None, which_g="avg", scale=1, gmax_mev=250, ph_modes=None, + with_ebands=True, ax_mat=None, fontsize=8, **kwargs) -> Figure: """ Plot the averaged |g(k,q)| in meV units along the k-path @@ -204,60 +230,49 @@ def plot_g_kpath(self, band_range=None, which_g="sym", scale=1, with_ebands=True band_range: Band range that will be averaged over (python convention). which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both. scale: Scaling factor for the marker size used when with_phbands is True. + gmax_mev: Show results up to gmax in meV. + ph_modes: List of ph branch indices to show (start from 0). If None all modes are show. with_ebands: False if electron bands should now be displayed. ax_mat: List of |matplotlib-Axes| or None if a new figure should be created. fontsize: fontsize for legends and titles """ - nrows, ncols = 1 + int((np.array([with_ebands]) == True).sum()), self.r.nsppol - which_g_list = [which_g] - if which_g == "all": - which_g_list = ["avg", "raw"] - nrows += 1 + which_g_list = self._get_which_g_list(which_g) + nrows, ncols = len(which_g_list) + int((np.array([with_ebands]) == True).sum()), self.r.nsppol ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols, sharex=False, sharey=False, squeeze=False) - marker_color = "gold" - band_range = (self.r.bstart, self.r.bstop) if band_range is None else band_range + band_range = self._get_band_range(band_range) for spin in range(self.r.nsppol): g_nuk_avg, g_nuk_raw = self.r.get_gnuk_average_spin(spin, band_range) ax_cnt = -1 for which_g in which_g_list: - # Select data according to which_g + # Select ys according to which_g g_nuk = dict(avg=g_nuk_avg, raw=g_nuk_raw)[which_g] + # Plot g_nu(q) ax_cnt += 1 ax = ax_mat[ax_cnt, spin] for nu in range(self.r.natom3): + if ph_modes is not None and nu not in ph_modes: continue ax.plot(g_nuk[nu], label=f"{which_g} {nu=}") - # Plot g(k) self.ebands_k.decorate_ax(ax, units="meV") - set_grid_legend(ax, fontsize, ylabel=r"$|g^{\text{%s}}(\mathbf{k})|$ (meV)" % (which_g)) + set_grid_legend(ax, fontsize, ylabel=r"$|g^{\text{%s}}_{\mathbf{k}}|$ (meV)" % (which_g)) + if gmax_mev is not None: + set_axlims(ax, [0, gmax_mev], "y") if with_ebands: - # Plot electron bands + averaged g(k) as markers + # Plot electron bands ax_cnt += 1 - points = None - #x, y, s = [], [], [] - #for ik, kpoint in enumerate(self.ebands_k.kpoints): - # omegas_nu = self.phbands.phfreqs[iq,:] - # for w, g2 in zip(omegas_nu, g_nuk[:,iq], strict=True): - # x.append(iq); y.append(w); s.append(scale * g2) - - #points = Marker(x, y, s, color=marker_color, edgecolors='gray', alpha=0.8, - # label=r'$|g^{\text{avg}}(\mathbf{k})|$ (meV)') - ax = ax_mat[ax_cnt, spin] self.ebands_k.plot(ax=ax, spin=spin, band_range=band_range, with_gaps=False, show=False) set_grid_legend(ax, fontsize) #, xlabel=r"Wavevector $\mathbf{q}$") - #self.phbands.plot(ax=ax, points=points, show=False) - if (qpt_name := self.structure.findname_in_hsym_stars(self.r.eph_fix_wavec)) is None: - qpt_name = str(self.r.eph_fix_wavec) + qpt_name = k2s(self.r.eph_fix_wavec) fig.suptitle(f"q = {qpt_name}" + f" m, n = {band_range[0]} - {band_range[1] - 1}") @@ -269,7 +284,7 @@ def yield_figs(self, **kwargs): # pragma: no cover """ if self.r.eph_fix_korq == "k": #yield self.ebands_kq.plot(show=False) - yield self.phbands.plot(show=False) + #yield self.phbands.plot(show=False) yield self.plot_g_qpath() if self.r.eph_fix_korq == "q": @@ -320,7 +335,11 @@ def __init__(self, filepath: PathLike): # Read important variables. self.eph_fix_korq = self.read_string("eph_fix_korq") + if self.eph_fix_korq not in {"k", "q"}: + raise ValueError(f"Invalid value for {self.eph_fix_korq=}") self.eph_fix_wavec = self.read_value("eph_fix_wavevec") + self.dbdb_add_lr = self.read_value("dvdb_add_lr") + #self.used_ftinterp = self.read_value("used_ftinterp") #self.completed = self.read_value("gstore_completed") # Note conversion Fortran --> C for the bstart index. @@ -594,57 +613,68 @@ class GpathRobot(Robot, RobotWithEbands): "t05o_GPATH.nc", ]) - .. rubric:: Inheritance Diagram - .. inheritance-diagram:: GstoreRobot + .. inheritance-diagram:: GpathRobot """ EXT = "GPATH" - #def neq(self, ref_basename: str | None = None, verbose: int = 0) -> int: - # """ - # Compare all GPATHE.nc files stored in the robot - # """ - # # Find reference gstore. By default the first file in the robot is used. - # ref_gstore = self._get_ref_abifile_from_basename(ref_basename) - - # exc_list = [] - # ierr = 0 - # for other_gstore in self.abifiles: - # if ref_gstore.filepath == other_gstore.filepath: - # continue - # print("Comparing: ", ref_gstore.basename, " with: ", other_gstore.basename) - # try: - # ierr += self._neq_two_gstores(ref_gstore, other_gstore, verbose) - # cprint("EQUAL", color="green") - # except Exception as exc: - # exc_list.append(str(exc)) - - # for exc in exc_list: - # cprint(exc, color="red") - - # return ierr - - #@staticmethod - #def _neq_two_gstores(gstore1: GstoreFile, gstore2: GstoreFile, verbose: int) -> int: - # """ - # Helper function to compare two GSTORE files. - # """ - # # These quantities must be the same to have a meaningfull comparison. - # aname_list = ["structure", "nsppol", "cplex", "nkbz", "nkibz", - # "nqbz", "nqibz", "completed", "kzone", "qzone", "kfilter", "gmode", - # "brange_spin", "erange_spin", "glob_spin_nq", "glob_nk_spin", - # ] - - # for aname in aname_list: - # self._compare_attr_name(aname, gstore1, gstore2) - - # # Now compare the gkq objects for each spin. - # ierr = 0 - # for spin in range(gstore1.nsppol): - # gqk1, gqk2 = gstore1.gqk_spin[spin], gstore2.gqk_spin[spin] - # ierr += gqk1.neq(gqk2, verbose) - - # return ierr + @add_fig_kwargs + def plot_g_qpath(self, which_g="avg", gmax_mev=250, ph_modes=None, + colormap="jet", **kwargs) -> Figure: + """ + Compare the g-matrix along a q-path. + + Args + which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both. + gmax_mev: Show results up to gmax in me + ph_modes: List of ph branch indices to show (start from 0). If None all modes are show. + colormap: Color map. Have a look at the colormaps here and decide which one you like: + http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html + """ + nsppol, nq_path, natom3, eph_fix_wavec, eph_fix_korq = self.getattrs_alleq( + "nsppol", "nq_path", "natom3", "eph_fix_wavec", "eph_fix_korq" + ) + xs = np.arange(nq_path) + + nrows, ncols = 1, nsppol + ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, + sharex=False, sharey=False, squeeze=False) + cmap = plt.get_cmap(colormap) + + # TODO: Compute common band range. + band_range = None + ref_ifile= 0 + #q_label = r"$|q|^{%d}$" % with_qexp if with_qexp else "" + #g_units = "(meV)" if with_qexp == 0 else r"(meV $\AA^-{%s}$)" % with_qexp + + for spin in range(nsppol): + ax_cnt = 0 + ax = ax_mat[ax_cnt, spin] + + for ifile, gpath in enumerate(self.abifiles): + g_nuq_avg, g_nuq_raw = gpath.r.get_gnuq_average_spin(spin, band_range) + # Select ys according to which_g and multiply by facts_q + g_nuq = dict(avg=g_nuq_avg, raw=g_nuq_raw)[which_g] # * facts_q[None,:] + + for nu in range(natom3): + if ph_modes is not None and nu not in ph_modes: continue + color = cmap(nu / natom3) + if ifile == ref_ifile: + ax.scatter(xs, g_nuq[nu], color=color, label=f"{nu=}", marker="o") + gpath.phbands.decorate_ax(ax, units="meV") + #g_label = r"$|g^{\text{%s}}_{\mathbf{q}}|$ %s" % (which_g, q_label) + #set_grid_legend(ax, fontsize, ylabel="%s %s" % (g_label, g_units)) + else: + ax.plot(g_nuq[nu], color=color, label=f"{nu=}") + + #if gmax_mev is not None and with_qexp == 0: + if gmax_mev is not None: + set_axlims(ax, [0, gmax_mev], "y") + + return fig + + #@add_fig_kwargs + #def plot_g_kpath(self, **kwargs) --> Figure def yield_figs(self, **kwargs): # pragma: no cover """