Skip to content

Commit

Permalink
Fix bug in plot_ank_with_ebands due to inonsistency in the order of b…
Browse files Browse the repository at this point in the history
…z2ibz and bz_kpoints
  • Loading branch information
gmatteo committed Nov 23, 2024
1 parent 218444e commit 57a077d
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gh-pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Deploy documentation with GitHub Pages dependencies preinstalled

on:
push:
branches: ["develop"]
branches: ["develop", "master"]
workflow_dispatch: # enable manual workflow execution

# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
Expand Down
48 changes: 33 additions & 15 deletions abipy/core/kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,58 +189,65 @@ def kmesh_from_mpdivs(mpdivs, shifts, pbc=False, order="bz"):
rc1 = rc_list(mpdivs[1], shift[1], pbc=pbc, order=order)
rc2 = rc_list(mpdivs[2], shift[2], pbc=pbc, order=order)

# NB: z is the fastest index in (x, y, z)
for kxyz in product(rc0, rc1, rc2):
#print("kxyz", kxyz)
kbz.append(kxyz)

#import sys; sys.exit(1)
return np.array(kbz)


def map_grid2ibz(structure, ibz, ngkpt, has_timrev, pbc=False):
def map_grid2ibz(structure, ibz, ngkpt, shifts, has_timrev, pbc=False):
"""
Compute the correspondence between a *grid* of k-points in the *unit cell*
associated to the ``ngkpt`` mesh and the corresponding points in the IBZ.
Requires structure with Abinit symmetries.
This routine is mainly used to symmetrize eigenvalues in the unit cell
e.g. to write BXSF files for electronic isosurfaces.
Args:
structure: Structure with (Abinit) symmetry operations.
ibz: [*, 3] array with reduced coordinates in the in the IBZ.
structure: Structure with Abinit symmetry operations.
ibz: [*, 3] array with the reduced coordinates in the IBZ.
ngkpt: Mesh divisions.
shifts:
has_timrev: True if time-reversal can be used.
pbc: True if the mesh should contain the periodic images (closed mesh).
Returns:
bz2ibz: 1d array with BZ --> IBZ mapping
"""
ngkpt = np.asarray(ngkpt, dtype=int)

# Extract (FM) symmetry operations in reciprocal space.
abispg = structure.abi_spacegroup
if abispg is None:
if (abispg := structure.abi_spacegroup) is None:
raise ValueError("Structure does not contain Abinit spacegroup info!")

ngkpt = np.asarray(ngkpt, dtype=int)
shifts = np.reshape(shifts, (-1, 3))

# TODO: Handle multiple shifts
if np.any(np.abs(shifts) > 0):
raise ValueError("The k-mesh should be gamma-centered!")
if len(shifts) > 1:
raise ValueError("Multiple shifts are not supported!")

# Extract rotations in reciprocal space (FM part).
symrec_fm = [o.rot_g for o in abispg.fm_symmops]

# Compute TS k_ibz.
bzgrid2ibz = -np.ones(ngkpt, dtype=int)

for ik_ibz, kibz in enumerate(ibz):
gp_ibz = np.array(np.rint(kibz * ngkpt), dtype=int)
for rot in symrec_fm:
# Compute S k_ibz.
rot_gp = np.matmul(rot, gp_ibz)
gp_bz = rot_gp % ngkpt
bzgrid2ibz[gp_bz[0], gp_bz[1], gp_bz[2]] = ik_ibz
if has_timrev:
# Compute TS k_ibz.
gp_bz = (-rot_gp) % ngkpt
bzgrid2ibz[gp_bz[0], gp_bz[1], gp_bz[2]] = ik_ibz

if pbc:
# Add periodic replicas.
bzgrid2ibz = add_periodic_replicas(bzgrid2ibz)

# Consistency check.
if np.any(bzgrid2ibz == -1):
#for ik_bz, ik_ibz in enumerate(self.bzgrid2ibz): print(ik_bz, ">>>", ik_ibz)
msg = " Found %s/%s invalid entries in bzgrid2ibz array\n" % ((bzgrid2ibz == -1).sum(), bzgrid2ibz.size)
Expand All @@ -249,8 +256,20 @@ def map_grid2ibz(structure, ibz, ngkpt, has_timrev, pbc=False):
msg += f" {abispg=}\n"
raise ValueError(msg)

bz2ibz = bzgrid2ibz.flatten()
return bz2ibz
# Generate k-points using numpy's meshgrid and stack for efficiency
nx, ny, nz = ngkpt[0], ngkpt[1], ngkpt[2]
if pbc:
nx, ny, nz = nx + 1, ny + 1, nz + 1

kx = (np.arange(nx) + shifts[0,0]) / ngkpt[0]
ky = (np.arange(ny) + shifts[0,1]) / ngkpt[1]
kz = (np.arange(nz) + shifts[0,2]) / ngkpt[2]

# Create the 3D grid of points
kx, ky, kz = np.meshgrid(kx, ky, kz, indexing="ij")
bz_kpoints = np.stack((kx.ravel(), ky.ravel(), kz.ravel()), axis=-1)

return bzgrid2ibz.flatten(), bz_kpoints


def has_timrev_from_kptopt(kptopt):
Expand Down Expand Up @@ -561,7 +580,6 @@ def as_kpoints(obj, lattice, weights=None, names=None):
if names is None: names = nk * [None]
return [Kpoint(rc, lattice, weight=w, name=l) for (rc, w, l) in zip(obj, weights, names)]


raise ValueError(f"{ndim=} > 2 is not supported!")


Expand Down
3 changes: 2 additions & 1 deletion abipy/core/tests/test_kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,8 @@ def setUp(self):

def test_map_grid2ibz(self):
"""Testing map_grid2ibz."""
bz2ibz = map_grid2ibz(self.mgb2, self.kibz, self.ngkpt, self.has_timrev, pbc=False)
shifts = [0, 0, 0]
bz2ibz, bz_kpoints = map_grid2ibz(self.mgb2, self.kibz, self.ngkpt, shifts, self.has_timrev, pbc=False)

bz = []
nx, ny, nz = self.ngkpt
Expand Down
2 changes: 1 addition & 1 deletion abipy/dfpt/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pymatgen.phonon.dos import CompletePhononDos as PmgCompletePhononDos, PhononDos as PmgPhononDos
from abipy.core.func1d import Function1D
from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_PhononBands, NotebookWriter
from abipy.core.kpoints import Kpoint, Kpath, KpointList, kmesh_from_mpdivs, map_grid2ibz
from abipy.core.kpoints import Kpoint, Kpath, KpointList, kmesh_from_mpdivs
from abipy.core.structure import Structure
from abipy.abio.robots import Robot
from abipy.iotools import ETSF_Reader
Expand Down
53 changes: 25 additions & 28 deletions abipy/electrons/ebands.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,24 +965,21 @@ def has_timrev(self) -> bool:

def get_bz2ibz_bz_points(self, require_gamma_centered=False):
"""
Return named tupled with mapping bz2ibz and the list of k-points in the BZ.
Return named tupled with the mapping bz2ibz and the list of k-points in the BZ.
Args:
require_gamma_centered: True if the k-mesh should be Gamma-centered
"""
err_msg = self.isnot_ibz_sampling(require_gamma_centered=require_gamma_centered)
if err_msg:
if err_msg := self.isnot_ibz_sampling(require_gamma_centered=require_gamma_centered):
raise TypeError(err_msg)

if not self.kpoints.is_mpmesh:
raise ValueError("Only Monkhorst-Pack meshes are supported")
raise ValueError("Only Monkhorst-Pack meshes are supported.")

ngkpt, shifts = self.kpoints.mpdivs_shifts
#print(f"{ngkpt = }, {shifts =}")
# TODO: Handle shifts

bz2ibz = map_grid2ibz(self.structure, self.kpoints.frac_coords, ngkpt, self.has_timrev)
bz_kpoints = kmesh_from_mpdivs(ngkpt, shifts)
# Use symmetries to map self.kpoints.frac_coords to ngkpt mesh.
bz2ibz, bz_kpoints = map_grid2ibz(self.structure, self.kpoints.frac_coords, ngkpt, shifts, self.has_timrev)

return dict2namedtuple(bz2ibz=bz2ibz, ngkpt=ngkpt, shifts=shifts, bz_kpoints=bz_kpoints)

Expand Down Expand Up @@ -1616,7 +1613,7 @@ def direct_gaps(self) -> List[ElectronTransition]:
gaps = np.array(gaps)
kinds = np.where(gaps == gaps.min())[0]
kdir = kinds[0]
all_kinds = list(zip(kinds, kinds))
all_kinds = list(zip(kinds, kinds, strict=True))
#kdir = kinds[len(kinds) // 2]
#kdir = np.array(gaps).argmin()
dirgaps[spin] = ElectronTransition(self.homo_sk(spin, kdir), self.lumo_sk(spin, kdir), all_kinds=all_kinds)
Expand Down Expand Up @@ -2461,7 +2458,7 @@ def plot_split(self, ylims_list: list,
# Adjust space between Axes
fig.subplots_adjust(hspace=hspace)

for ix, (ax, ylims) in enumerate(zip(ax_list, ylims_list)):
for ix, (ax, ylims) in enumerate(zip(ax_list, ylims_list, strict=True)):
# Plot the same data on all Axes
self.plot(ax=ax, show=False, **kwargs)
# Zoom-in / limit the view to different portions of the data.
Expand Down Expand Up @@ -2926,7 +2923,7 @@ def plot_lws_vs_e0(self, ax=None, e0="fermie", function=lambda x: x, exchange_xy
xlabel = r"$\epsilon_{KS}-\epsilon_F\;(eV)$"

# DSU sort to get lw(e) with sorted energies.
e0mesh, lws = zip(*sorted(zip(self.eigens.flat, self.linewidths.flat), key=lambda t: t[0]))
e0mesh, lws = zip(*sorted(zip(self.eigens.flat, self.linewidths.flat), key=lambda t: t[0]), strict=True)
e0 = self.get_e0(e0)
e0mesh = np.array(e0mesh) - e0

Expand Down Expand Up @@ -3002,7 +2999,7 @@ def w(s):

kticks, klabels = self._make_ticks_and_labels(klabels=None)
w('@xaxis tick spec %d' % len(kticks))
for ik, (ktick, klabel) in enumerate(zip(kticks, klabels)):
for ik, (ktick, klabel) in enumerate(zip(kticks, klabels, strict=True)):
w('@xaxis tick major %d, %d' % (ik, ktick))
w('@xaxis ticklabel %d, "%s"' % (ik, klabel))

Expand Down Expand Up @@ -3339,7 +3336,7 @@ def interpolate(self, lpratio=5, knames=None, vertices_names=None, line_density=
if (abispg := self.structure.abi_spacegroup) is None:
abispg = self.structure.spgset_abi_spacegroup(has_timerev=self.has_timrev)

fm_symrel = [s for (s, afm) in zip(abispg.symrel, abispg.symafm) if afm == 1]
fm_symrel = [s for (s, afm) in zip(abispg.symrel, abispg.symafm, strict=True) if afm == 1]

if self.nband > self.nelect and self.nband > 20 and bstart == 0 and bstop is None:
cprint("Bands object contains nband %s with nelect %s. You may want to use bstart, bstop to select bands." % (
Expand Down Expand Up @@ -3728,7 +3725,7 @@ def combiplotly(self, e0="fermie", ylims=None, width_ratios=(2, 1), fontsize=12,
if any(nk != nkpt_list[0] for nk in nkpt_list):
cprint("WARNING: Bands have different number of k-points:\n%s" % str(nkpt_list), "yellow")

for (label, ebands), lineopt in zip(self.ebands_dict.items(), self.iter_lineopt_plotly()):
for (label, ebands), lineopt in zip(self.ebands_dict.items(), self.iter_lineopt_plotly(), strict=True):
i += 1
if linestyle_dict is not None and label in linestyle_dict:
my_kwargs.update(linestyle_dict[label])
Expand Down Expand Up @@ -3822,7 +3819,7 @@ def gridplot(self, e0="fermie", with_dos=True, with_gaps=False, max_phfreq=None,
# don't show the last ax if numeb is odd.
if numeb % ncols != 0: ax_list[-1].axis("off")

for i, (ebands, ax) in enumerate(zip(ebands_list, ax_list)):
for i, (ebands, ax) in enumerate(zip(ebands_list, ax_list), strict=True):
irow, icol = divmod(i, ncols)
ebands.plot(ax=ax, e0=e0, with_gaps=with_gaps, max_phfreq=max_phfreq, fontsize=fontsize, show=False)
set_axlims(ax, ylims, "y")
Expand All @@ -3838,7 +3835,7 @@ def gridplot(self, e0="fermie", with_dos=True, with_gaps=False, max_phfreq=None,
fig = plt.figure()
gspec = GridSpec(nrows, ncols)

for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list)):
for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list), strict=True):
subgrid = GridSpecFromSubplotSpec(1, 2, subplot_spec=gspec[i], width_ratios=[2, 1], wspace=0.05)
# Get axes and align bands and DOS.
ax0 = plt.subplot(subgrid[0])
Expand Down Expand Up @@ -3934,7 +3931,7 @@ def gridplotly(self, e0="fermie", with_dos=True, with_gaps=False, max_phfreq=Non
column_widths=[2, 1]*2, horizontal_spacing=0.02, sharex=False, sharey=True)
# all sub_fig in the same row will share y

for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list)):
for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list), strict=True):
# Align bands and DOS.
irow, icol = divmod(i, 2)
band_rcd = PlotlyRowColDesc(irow, icol * 2, nrows, ncols)
Expand Down Expand Up @@ -3984,7 +3981,7 @@ def boxplot(self, e0="fermie", brange=None, swarm=False, fontsize=8, **kwargs) -
# don't show the last ax if num_plots is odd.
if num_plots % ncols != 0: ax_list[-1].axis("off")

for (label, ebands), ax in zip(self.ebands_dict.items(), ax_list):
for (label, ebands), ax in zip(self.ebands_dict.items(), ax_list, strict=True):
ebands.boxplot(ax=ax, brange=brange, show=False)
ax.set_title(label, fontsize=fontsize)

Expand Down Expand Up @@ -4032,7 +4029,7 @@ def combiboxplot(self, e0="fermie", brange=None, swarm=False, ax=None, **kwargs)
if ax is not None:
raise NotImplementedError("ax == None not implemented when nsppol==2")
fig, ax_list = plt.subplots(nrows=2, ncols=1, sharex=True, squeeze=False)
for spin, ax in zip(range(2), ax_list.ravel()):
for spin, ax in zip(range(2), ax_list.ravel(), strict=True):
ax.grid(True)
data_spin = data[data["spin"] == spin]
sns.boxplot(x="band", y="eig", data=data_spin, hue="label", ax=ax, **kwargs)
Expand Down Expand Up @@ -4157,7 +4154,7 @@ def animate(self, e0="fermie", interval=500, savefile=None, width_ratios=(2, 1),
ax1.yaxis.set_ticks_position("right")
ax1.yaxis.set_label_position("right")

for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list)):
for i, (ebands, edos) in enumerate(zip(ebands_list, edos_list), strict=True):
# Define the zero of energy to align bands and dos
mye0 = ebands.get_e0(e0) if e0 != "edos_fermie" else edos.fermie
ebands_lines = ebands.plot_ax(ax0, mye0, **plotax_kwargs)
Expand Down Expand Up @@ -4435,7 +4432,7 @@ def as_edos(cls, obj: Any, edos_kwargs: dict) -> ElectronDos:
def __eq__(self, other):
if other is None: return False
if self.nsppol != other.nsppol: return False
for f1, f2 in zip(self.spin_dos, other.spin_dos):
for f1, f2 in zip(self.spin_dos, other.spin_dos, strict=True):
if f1 != f2: return False
return True

Expand Down Expand Up @@ -4829,7 +4826,7 @@ def to_pymatgen(self):
Return a pymatgen DOS object from an Abipy |ElectronDos| object.
"""
from pymatgen.electronic_structure.dos import Dos
den = {s: d.values for d, s in zip(self.spin_dos, [PmgSpin.up, PmgSpin.down])}
den = {s: d.values for d, s in zip(self.spin_dos, [PmgSpin.up, PmgSpin.down], strict=True)}
pmg_dos = Dos(energies=self.spin_dos[0].mesh, densities=den, efermi=self.fermie)

return pmg_dos
Expand Down Expand Up @@ -4911,7 +4908,7 @@ def combiplot(self, what_list="dos", spin_mode="automatic", e0="fermie",
ax_list = ax_list.ravel()

can_use_basename = self._can_use_basenames_as_labels()
for i, (what, ax) in enumerate(zip(what_list, ax_list)):
for i, (what, ax) in enumerate(zip(what_list, ax_list, strict=True)):
for label, edos in self.edoses_dict.items():
if can_use_basename:
label = os.path.basename(label)
Expand Down Expand Up @@ -5080,7 +5077,7 @@ def gridplot(self, what="dos", spin_mode="automatic", e0="fermie",
# don't show the last ax if numeb is odd.
if numeb % ncols != 0: ax_list[-1].axis("off")

for i, ((label, edos), ax) in enumerate(zip(self.edoses_dict.items(), ax_list)):
for i, ((label, edos), ax) in enumerate(zip(self.edoses_dict.items(), ax_list, strict=True)):
irow, icol = divmod(i, ncols)

# Here I handle spin and spin_mode.
Expand Down Expand Up @@ -5285,7 +5282,7 @@ def __init__(self, structure, ibz, has_timrev, eigens, fermie):

# Xcrysden requires points in the unit cell (C-order)
# and the mesh must include the periodic images hence pbc=True.
self.uc2ibz = map_grid2ibz(self.structure, self.ibz.frac_coords, mpdivs, self.has_timrev, pbc=True)
self.uc2ibz, _ = map_grid2ibz(self.structure, self.ibz.frac_coords, mpdivs, shifts, self.has_timrev, pbc=True)
self.mpdivs = mpdivs
self.kdivs = mpdivs + 1
self.spacing = 1.0 / mpdivs
Expand Down Expand Up @@ -5832,7 +5829,7 @@ def get_xy(item, spin, all_xvals, all_abifiles):
groups = self.group_and_sortby(hue, sortby)

marker_spin = {0: "^", 1: "v"}
for i, (ax, item) in enumerate(zip(ax_list, items)):
for i, (ax, item) in enumerate(zip(ax_list, items), strict=True):
for spin in range(max_nsppol):
if hue is None:
# Extract data.
Expand Down Expand Up @@ -5907,7 +5904,7 @@ def gridplot_with_hue(self, hue, ylims=None, fontsize=8,
ax_list = ax_list.ravel()
e0 = "fermie" # Each ebands is aligned with respect to its Fermi energy.

for ax, grp in zip(ax_list, groups):
for ax, grp in zip(ax_list, groups, strict=True):
ax.grid(True)
ebands_list = [abifile.ebands for abifile in grp.abifiles]
ax.set_title("%s = %s" % (self._get_label(hue), grp.hvalue), fontsize=fontsize)
Expand All @@ -5916,7 +5913,7 @@ def gridplot_with_hue(self, hue, ylims=None, fontsize=8,
if any(nk != nkpt_list[0] for nk in nkpt_list):
cprint("WARNING: Bands have different number of k-points:\n%s" % str(nkpt_list), "yellow")

for i, (ebands, lineopts) in enumerate(zip(ebands_list, self.iter_lineopt())):
for i, (ebands, lineopts) in enumerate(zip(ebands_list, self.iter_lineopt(), strict=True)):
# Plot all branches with lineopts and set the label of the last line produced.
ebands.plot_ax(ax, e0, **lineopts)
ax.lines[-1].set_label("%s" % grp.labels[i])
Expand Down
Loading

0 comments on commit 57a077d

Please sign in to comment.