Skip to content

Commit

Permalink
a couple of new plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
harrisonlabollita committed Jan 4, 2024
1 parent 33c19cb commit 4227a47
Showing 1 changed file with 120 additions and 21 deletions.
141 changes: 120 additions & 21 deletions w2kplot/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

from . import w2kplot_base_style, w2kplot_bands_style

from matplotlib.colors import LinearSegmentedColormap

def colormap(*args):
return LinearSegmentedColormap.from_list('', list(args), N=256)


# Bands class
class Bands(object):
Expand All @@ -50,9 +55,10 @@ def __init__(self,
self.eF_shift = eF_shift

if self.spaghetti is None:
try:
self.spaghetti = glob.glob("*.spaghetti_ene")[0]
except BaseException:
for ext in ["*.spaghetti_ene", "*.spaghettiup_ene", "*.spaghettidn_ene"]:
spaghetti = glob.glob(ext)
if len(spaghetti) > 0: self.spaghetti = spaghetti[0]
if self.spaghetti is None:
raise FileNotFoundError(
"Could not find a case.spaghetti_ene file in this directory.\nPlease provide a case.spaghetti_ene file")

Expand Down Expand Up @@ -94,7 +100,7 @@ def _get_dft_bands(self):
kpoints = np.unique(data[:, 3])
Ek = data[:, 4].reshape(int(len(data) / len(kpoints)), len(kpoints))
break
except BaseException:
except BaseException:
skiprows += 1

return kpoints, Ek
Expand All @@ -116,9 +122,7 @@ def _get_high_symmetry_path(self):
high_symmetry_labels.append(self._arg2latex(line.strip().split()[0]))
high_symmetry_points.append(il)
high_symmetry_points = [self.kpoints[ind] for ind in high_symmetry_points]
except BaseException:
raise Exception(
"An error occured when trying to parse the {} file".format(self.klist_band))
except BaseException: raise Exception("An error occured when trying to parse the {} file".format(self.klist_band))

return high_symmetry_points, high_symmetry_labels

Expand Down Expand Up @@ -146,23 +150,18 @@ def Up(case=None, **kwargs):
try:
spaghetti = glob.glob("*.spaghettiup_ene")[0]
except BaseException:
raise FileNotFoundError(
"Could not find a case.spaghettiup_ene file in this directory.\nPlease provide a case.spaghettiup_ene file")
raise FileNotFoundError("Could not find a case.spaghettiup_ene file in this directory.\nPlease provide a case.spaghettiup_ene file")
return Bands(spaghetti=spaghetti, **kwargs)
else:
return Bands(spaghetti=case + '.spaghettiup_ene', klist_band=case + '.klist_band', **kwargs)
else: return Bands(spaghetti=case + '.spaghettiup_ene', klist_band=case + '.klist_band', **kwargs)

@staticmethod
def Down(case=None, **kwargs):
if case is None:
try:
spaghetti = glob.glob("*.spaghettidn_ene")[0]
try: spaghetti = glob.glob("*.spaghettidn_ene")[0]
except BaseException:
raise FileNotFoundError(
"Could not find a case.spaghettidn_ene file in this directory.\nPlease provide a case.spaghetti_ene file")
raise FileNotFoundError("Could not find a case.spaghettidn_ene file in this directory.\nPlease provide a case.spaghetti_ene file")
return Bands(spaghetti=spaghetti, **kwargs)
else:
return Bands(spaghetti=case + '.spaghettidn_ene', klist_band=case + '.klist_band', **kwargs)
else: return Bands(spaghetti=case + '.spaghettidn_ene', klist_band=case + '.klist_band', **kwargs)

# bandstructure plotting
def __band_plot(figure, bands, *opt_list, **opt_dict):
Expand Down Expand Up @@ -203,9 +202,9 @@ def __band_plot(figure, bands, *opt_list, **opt_dict):


# band_plot functions
plt.style.use([w2kplot_base_style, w2kplot_base_style])
plt.style.use([w2kplot_base_style, w2kplot_bands_style])
def band_plot(bands, *opt_list, **opt_dict): __band_plot(plt, bands, *opt_list, **opt_dict)
plt.style.use([w2kplot_base_style, w2kplot_base_style])
plt.style.use([w2kplot_base_style, w2kplot_bands_style])
mpl.axes.Axes.band_plot = lambda self, bands, *opt_list, **opt_dict: __band_plot(self, bands, *opt_list, **opt_dict)


Expand All @@ -222,7 +221,9 @@ def __init__(self,
qtl: str = None,
eF: Union[str, float] = None,
struct: str = None,
eF_shift: float = 0) -> None:
eF_shift: float = 0,
alpha : float = 1,
) -> None:
"""
Initialize the FatBand data object. This class is a child of the Bands class.
Expand Down Expand Up @@ -265,6 +266,7 @@ def __init__(self,
self.orbitals = orbitals
self.weight = weight
self.colors = colors
self.alpha = alpha

# modify file names if case is available
struct = case + '.struct' if case else struct
Expand Down Expand Up @@ -397,7 +399,7 @@ def __fatband_plot(figure, fat_bands, *opt_list, **opt_dict):
else:
assert len(fat_bands.kpoints) == len(E), f"Did not parse file correctly! {len(fat_bands.kpoints), len(E)}"
assert len(E) == len(character), "Did not parse file correctly!"
figure.scatter(fat_bands.kpoints, E, character, fat_bands.colors[a][o], rasterized=True)
figure.scatter(fat_bands.kpoints, E, character, fat_bands.colors[a][o], rasterized=True, alpha=fat_bands.alpha)
E, character = [], []

# fatband plot functions
Expand All @@ -406,3 +408,100 @@ def fatband_plot(fat_bands,*opt_list, **opt_dict): __fatband_plot(plt, fat_bands

plt.style.use([w2kplot_base_style, w2kplot_base_style])
mpl.axes.Axes.fatband_plot = lambda self, fat_bands, *opt_list, **opt_dict: __fatband_plot(self, fat_bands, *opt_list, **opt_dict)

def __otherfatband_plot(figure, fat_bands, *opt_list, **opt_dict):
if isinstance(figure, types.ModuleType):
figure = figure.gca()

# plot the bands
__band_plot(figure, fat_bands, *opt_list, **opt_dict)

qtl_file = open(fat_bands.qtl)
qtl = qtl_file.readlines()
qtl_file.close()
start = [line + 1 for line in range(len(qtl)) if "BAND" in qtl[line]][0]
qtl = qtl[start:]

# plot the fatband character
for (a, at) in enumerate(fat_bands.atoms):
#for o in range(len(fat_bands.orbitals[a])):
#assert len(fat_bands.orbitals[a]) == 2
E, character, size = [], [], []
for line in qtl:
if 'BAND' not in line:
if line.split()[1] == str(at):
# wien2k interal units are Ry switch to eV
E.append((float(line.split()[0]) - fat_bands.eF) * fat_bands.Ry2eV - fat_bands.eF_shift)
# weight factor
enh = float(fat_bands.weight * fat_bands.structure.atoms[at - 1][1])
# qtl overlap
ovlap1 = (float(line.split()[int(fat_bands.orbitals[a][0]) + 1]))
ovlap2 = (float(line.split()[int(fat_bands.orbitals[a][1]) + 1]))
#ovlap3 = sum([float(line.split()[int(fat_bands.orbitals[a][k]) + 1]) for k in [2,3,4]])/3
#ovlap = ovlap1 if ovlap1 >= ovlap2 else -ovlap2
size.append(enh * (ovlap1+ovlap2))
#character.append(1 - np.exp(-1*np.array([ovlap1, ovlap3, ovlap2, ovlap1+ovlap2+ovlap3]).dot([10,2,10,1])))
#character.append(1 - np.exp(10*(ovlap2-ovlap1)))
character.append(ovlap1-ovlap2)
else:
assert len(fat_bands.kpoints) == len(E), f"Did not parse file correctly! {len(fat_bands.kpoints), len(E)}"
assert len(E) == len(character), "Did not parse file correctly!"
figure.scatter(fat_bands.kpoints, E, s=size,
c=np.asarray(character),
cmap = colormap('dodgerblue', 'crimson'),
rasterized=True
)
E, character, size = [], [], []

# fatband plot functions
plt.style.use([w2kplot_base_style, w2kplot_base_style])
def otherfatband_plot(fat_bands,*opt_list, **opt_dict): __fatband_plot(plt, fat_bands, *opt_list, **opt_dict)

plt.style.use([w2kplot_base_style, w2kplot_base_style])
mpl.axes.Axes.otherfatband_plot = lambda self, fat_bands, *opt_list, **opt_dict: __otherfatband_plot(self, fat_bands, *opt_list, **opt_dict)


# bandstructure plotting
def __spectral_plot(figure, bands, *opt_list, **opt_dict):

if isinstance(figure, types.ModuleType):
figure = figure.gca()

try:
# new version of matplotlib
grid_spec = figure.get_subplotspec()
is_first_col = grid_spec.is_first_col()
is_last_row = grid_spec.is_last_row()
except BaseException:
# old version of matplotlib
is_first_col = figure.is_first_col()
is_last_row = figure.is_last_row()

figure.tick_params(axis='both',which='minor',length=3.5,width=0.5,labelsize=12,bottom=False,top=False)
figure.tick_params(axis='both', which='major',length=7, width=0.5, labelsize=12, bottom=False, top=False)

# build up spectral function
kmesh = np.linspace(bands.kpoints.min(), bands.kpoints.max(), len(bands.kpoints))
omega = np.linspace(bands.Ek.min(), bands.Ek.max(),5000)

Akw = np.zeros((len(kmesh), len(omega)))
for ik in range(len(kmesh)): Akw[ik,:] = (1/(omega[:,None] - bands.Ek[:,ik] +0.001j)).sum(axis=1).imag*(-1/np.pi)

kgrid, ogrid = np.meshgrid(kmesh, omega)
figure.pcolormesh(kgrid, ogrid, Akw.T, **opt_dict)

# decorate the figure from here
figure.set_xticks(bands.high_symmetry_points)
if is_last_row: figure.set_xticklabels(bands.high_symmetry_labels)
for k in bands.high_symmetry_points: figure.axvline(k, color="w", lw=1, ls='dotted')
figure.axhline(0.0, color="w", ls='dotted', lw=1)
# if we are the first column we will always add the ylabel
if is_first_col: figure.set_ylabel(r"$\varepsilon - \varepsilon_{\mathrm{F}}$ (eV)")
figure.set_ylim(-2, 2)
figure.set_xlim(bands.high_symmetry_points[0], bands.high_symmetry_points[-1])

plt.style.use([w2kplot_base_style, w2kplot_bands_style])
def spectral_plot(bands, *opt_list, **opt_dict): __spectral_plot(plt, bands, *opt_list, **opt_dict)

plt.style.use([w2kplot_base_style, w2kplot_bands_style])
mpl.axes.Axes.spectral_plot = lambda self, bands, *opt_list, **opt_dict: __spectral_plot(self, bands, *opt_list, **opt_dict)

0 comments on commit 4227a47

Please sign in to comment.