Skip to content

Commit

Permalink
Pass interp_method to plot_ank_with_ebands
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Nov 21, 2024
1 parent df88020 commit 7181567
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 18 deletions.
7 changes: 4 additions & 3 deletions abipy/electrons/fatbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def plot_fatbands_siteview(self, e0="fermie", view="inequivalent", fact=1.0, fon
return fig

@add_fig_kwargs
def plot_fatbands_lview(self, e0="fermie", fact=1.0, ax_mat=None, lmax=None,
def plot_fatbands_lview(self, e0="fermie", fact=1.0, ax_mat=None, lmin=0, lmax=None,
ylims=None, blist=None, fontsize=12, **kwargs) -> Figure:
"""
Plot the electronic fatbands grouped by L with matplotlib.
Expand All @@ -537,6 +537,7 @@ def plot_fatbands_lview(self, e0="fermie", fact=1.0, ax_mat=None, lmax=None,
- None: Don't shift energies, equivalent to ``e0 = 0``
fact: float used to scale the stripe size.
ax_mat: Matrix of axes, if None a new figure is produced.
lmin: Minimum L included in plot.
lmax: Maximum L included in plot. None means full set available on file.
ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
or scalar e.g. ``left``. If left (right) is None, default values are used
Expand All @@ -547,7 +548,7 @@ def plot_fatbands_lview(self, e0="fermie", fact=1.0, ax_mat=None, lmax=None,
"""
mylsize = self.lsize if lmax is None else lmax + 1
# Build or get grid with (nsppol, mylsize) axis.
nrows, ncols = self.nsppol, mylsize
nrows, ncols = self.nsppol, mylsize - lmin
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
ax_mat = np.reshape(ax_mat, (nrows, ncols))
Expand All @@ -558,7 +559,7 @@ def plot_fatbands_lview(self, e0="fermie", fact=1.0, ax_mat=None, lmax=None,
mybands = range(ebands.mband) if blist is None else blist

for spin in range(self.nsppol):
for l in range(mylsize):
for l in range(lmin, mylsize):
ax = ax_mat[spin, l]
ebands.plot_ax(ax, e0, spin=spin, **self.eb_plotax_kwargs(spin))
title = "%s, %s" % (self.l2tex[l], self.spin2tex[spin]) if self.nsppol == 2 else "%s" % self.l2tex[l]
Expand Down
51 changes: 37 additions & 14 deletions abipy/eph/varpeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,22 +436,31 @@ def insert_b_inbox(self, fill_value=None) -> tuple:

return b_data, ngqpt, shifts

def get_a2_interpolator_state(self) -> BzRegularGridInterpolator:
def get_a2_interpolator_state(self, interp_method) -> BzRegularGridInterpolator:
"""
Build and return an interpolator for |A_nk|^2 for each polaronic state.
Args:
interp_method: The method of interpolation to perform. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
"""
a_data, ngkpt, shifts = self.insert_a_inbox()

return [BzRegularGridInterpolator(self.structure, shifts, np.abs(a_data[pstate])**2, method="linear")
return [BzRegularGridInterpolator(self.structure, shifts, np.abs(a_data[pstate])**2, method=interp_method)
for pstate in range(self.nstates)]

def get_b2_interpolator_state(self) -> BzRegularGridInterpolator:
def get_b2_interpolator_state(self, interp_method) -> BzRegularGridInterpolator:
"""
Build and return an interpolator for |B_qnu|^2 for each polaronic state.
Args:
interp_method: The method of interpolation to perform. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
"""
b_data, ngqpt, shifts = self.insert_b_inbox()

return [BzRegularGridInterpolator(self.structure, shifts, np.abs(b_data[pstate])**2, method="linear")
return [BzRegularGridInterpolator(self.structure, shifts, np.abs(b_data[pstate])**2, method=interp_method)
for pstate in range(self.nstates)]

def write_a2_bxsf(self, filepath: PathLike, fill_value=0.0) -> None:
Expand Down Expand Up @@ -562,8 +571,9 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:

@add_fig_kwargs
def plot_ank_with_ebands(self, ebands_kpath,
ebands_kmesh=None, lpratio: int=5, method="gaussian", step: float=0.05, width: float=0.1,
nksmall: int=20, normalize: bool=False, with_title=True,
ebands_kmesh=None, lpratio: int=5,
with_ibz_a2dos=True, method="gaussian", step: float=0.05, width: float=0.1,
nksmall: int=20, normalize: bool=False, with_title=True, interp_method="linear",
ax_mat=None, ylims=None, scale=10, marker_color="gold", fontsize=12, **kwargs) -> Figure:
"""
Plot electron bands with markers with size proportional to |A_nk|^2.
Expand All @@ -575,10 +585,12 @@ def plot_ank_with_ebands(self, ebands_kpath,
normalize: Rescale the two DOS to plot them on the same scale.
lpratio: Ratio between the number of star functions and the number of ab-initio k-points.
The default should be OK in many systems, larger values may be required for accurate derivatives.
with_ibz_a2dos: True if A2_IBZ(E) should be computed.
method: Integration scheme for DOS
step: Energy step (eV) of the linear mesh for DOS computation.
width: Standard deviation (eV) of the gaussian for DOS computation.
with_title: True to add title with chemical formula and gaps.
interp_method: Interpolation method.
ax_mat: List of |matplotlib-Axes| or None if a new figure should be created.
ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
or scalar e.g. ``left``. If left (right) is None, default values are used.
Expand All @@ -591,14 +603,14 @@ def plot_ank_with_ebands(self, ebands_kpath,
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw)
# Get interpolators for A_nk
a2_interp_state = self.get_a2_interpolator_state()
a2_interp_state = self.get_a2_interpolator_state(interp_method)

# DEBUG SECTION
#ref_kn = np.abs(self.a_kn) ** 2
#ref_akn = np.abs(self.a_kn) ** 2
#for ik, kpoint in enumerate(self.kpoints):
# interp = a2_interp_state[0].eval_kpoint(kpoint)
# print("MAX (A2 ref - A2 interp) at qpoint", kpoint)
# print((np.abs(ref_kn[ik] - interp)).max())
# print((np.abs(ref_akn[ik] - interp)).max())

df = self.get_final_results_df()

Expand Down Expand Up @@ -652,11 +664,21 @@ def plot_ank_with_ebands(self, ebands_kpath,
for pstate in range(self.nstates):
# Compute A^2(E) DOS with A_nk in the full BZ
ank_dos = np.zeros(len(edos_mesh))
for ik_ibz, kpoint in zip(kmesh.bz2ibz, kmesh.bz_kpoints):
#a2_max, kpoint_max, band_max, tol = None, None, None, 0.4
for ik_ibz, kpoint in zip(kmesh.bz2ibz, kmesh.bz_kpoints, strict=True):
enes_n = ebands_kmesh.eigens[self.spin, ik_ibz, self.bstart:self.bstop]
a2_n = a2_interp_state[pstate].eval_kpoint(kpoint)
for e, a2 in zip(enes_n, a2_n):
for band, (e, a2) in enumerate(zip(enes_n, a2_n, strict=True)):
ank_dos += a2 * gaussian(edos_mesh, width, center=e-e0)
#print(float(e-e0), a2)
#if a2_max is None and np.any(np.abs(kpoint) > tol):
# a2_max, kpoint_max, band_max = a2, kpoint, band
#if a2_max is not None and a2 > a2_max and np.any(np.abs(kpoint) > tol):
# a2_max, kpoint_max, band_max = a2, kpoint, band

#band_max += self.bstart
#print(f"For {pstate=}, {a2_max=}, {kpoint_max=}, {band_max=}")

ank_dos /= np.product(kmesh.ngkpt)
ank_dos = Function1D(edos_mesh, ank_dos)
print(f"For {pstate=}, A^2(E) integrates to:", ank_dos.integral_value, " Ideally, it should be 1.")
Expand All @@ -668,13 +690,13 @@ def plot_ank_with_ebands(self, ebands_kpath,

# Computes A2(E) using only k-points in the IBZ. This is just for testing.
# A2_IBZ(E) should be equal to A2(E) only if A_nk fullfills the lattice symmetries. See notes above.
with_ibz_a2dos = True
if with_ibz_a2dos:
ank_dos = np.zeros(len(edos_mesh))
for ik_ibz, kpoint in enumerate(ebands_kmesh.kpoints):
weight = kpoint.weight
enes_n = ebands_kmesh.eigens[self.spin, ik_ibz, self.bstart:self.bstop]
for e, a2 in zip(enes_n, a2_interp_state[pstate].eval_kpoint(kpoint), strict=True):
#print(float(e-e0), a2)
ank_dos += weight * a2 * gaussian(edos_mesh, width, center=e-e0)
ank_dos = Function1D(edos_mesh, ank_dos)
print(f"For {pstate=}, A2_IBZ(E) integrates to:", ank_dos.integral_value, " Ideally, it should be 1.")
Expand Down Expand Up @@ -725,7 +747,7 @@ def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaddb_kwargs=None, **kwargs)
@add_fig_kwargs
def plot_bqnu_with_phbands(self, phbands_qpath,
phdos_file=None, ddb=None, width=0.001, normalize: bool=True,
verbose=0, anaddb_kwargs=None, with_title=True,
verbose=0, anaddb_kwargs=None, with_title=True, interp_method="linear",
ax_mat=None, scale=10, marker_color="gold", fontsize=12, **kwargs) -> Figure:
"""
Plot phonon energies with markers with size proportional to |B_qnu|^2.
Expand All @@ -739,6 +761,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
verbose:
anaddb_kwargs: Optional arguments passed to anaddb.
with_title: True to add title with chemical formula and gaps.
interp_method: Interpolation method.
ax_mat: List of |matplotlib-Axes| or None if a new figure should be created.
scale: Scaling factor for |B_qnu|^2.
marker_color: Color for markers.
Expand All @@ -755,7 +778,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
phbands_qpath = PhononBands.as_phbands(phbands_qpath)

# Get interpolators for B_qnu
b2_interp_state = self.get_b2_interpolator_state()
b2_interp_state = self.get_b2_interpolator_state(interp_method)

for pstate in range(self.nstates):
x, y, s = [], [], []
Expand Down
4 changes: 3 additions & 1 deletion abipy/tools/numtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,9 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs):
shifts: Shift of the mesh.
add_replicas: If True, data is padded with redundant data points.
in order to have a periodic 3D array of shape=[ndat, nx+1, ny+1, nz+1].
kwargs: Extra arguments are passed to RegularGridInterpolator.
kwargs: Extra arguments are passed to RegularGridInterpolator e.g.: method
The method of interpolation to perform. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
"""
self.structure = structure
self.shifts = shifts
Expand Down

0 comments on commit 7181567

Please sign in to comment.