Skip to content

Commit

Permalink
Up
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiHelleboid committed Mar 1, 2023
1 parent 08e61cf commit 9c5b75c
Show file tree
Hide file tree
Showing 3 changed files with 15,044 additions and 50 deletions.
245 changes: 196 additions & 49 deletions python/EPM/EPM.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@
import numpy as np
import scipy.linalg as la
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from matplotlib.lines import Line2D
import matplotlib as mpl

try:
plt.style.use(['science', 'high-vis', 'grid'])
except Exception:
None
plt.style.use(['seaborn-paper'])

mpl.rcParams['figure.figsize'] = [3.5, 2.8]


import itertools

Expand Down Expand Up @@ -188,6 +201,174 @@ def band_structure(path):
return np.stack(bands, axis=-1)


def eigen_states(k: np.ndarray, basis_G: np.ndarray=None) -> np.ndarray:
"""Generate the eigen states.
Args:
k (np.ndarray): The wave vector.
basis_G (np.ndarray): The basis vector.
Returns:
np.ndarray: The eigen states.
"""
if basis_G is None:
basis_G = EPM_BASIS
HamiltonianMatrix = hamiltonian(k, basis_G)
# Get eigenvalues and eigenvectors of the Hamiltonian matrix
eigen_values, eigen_vectors = la.eig(HamiltonianMatrix)
# Sort the eigenvalues and eigenvectors
idx = eigen_values.argsort()
eigen_values = eigen_values[idx]
eigen_vectors = eigen_vectors[:, idx]
return eigen_values, eigen_vectors

def plot_eigen_states(k: np.ndarray, basis_G: np.ndarray=None, n_states: int=8) -> None:
"""Plot the eigen states.
Args:
k (np.ndarray): The wave vector.
basis_G (np.ndarray): The basis vector.
n_states (int): The number of states to plot.
"""
if basis_G is None:
basis_G = EPM_BASIS
eigen_values, eigen_vectors = eigen_states(k, basis_G)
norms = np.linalg.norm(eigen_vectors, axis=0)
N = len(basis_G)
Ns = 10
# Plot the fourier factors colors is shade of blue
cmap = plt.cm.get_cmap('Blues')
colors = [cmap(i) for i in np.linspace(0, 1, Ns)]

fig, ax = plt.subplots(nrows=1, ncols=Ns, figsize=(40, 5), sharey=True)
for ix in range(Ns):
i = ix + 0
ax[ix].scatter(np.arange(N), eigen_vectors[:, i], label=f"{i}", color='b', s=4)
# Bar plot of the fourier factors
ax[ix].bar(np.arange(N), eigen_vectors[:, i], color='k', alpha=1, width=1)
ax[ix].set_title(f"State {i}")
ax[ix].set_xlim(-1.5, N+1.5)
ax[ix].set_ylim(-1.5, 1.5)
fig.tight_layout()

# ax.set_title(f"Fourier factors for k = {k}")

fig.tight_layout()
fig.savefig(f"fourier_factors_k_{k}.svg")
plt.show()

def get_wave_function_real_space(k: np.ndarray, r_min, r_max, n_points, basis_G: np.ndarray=None) -> np.ndarray:
"""Generate the wave function in real space.
Args:
k (np.ndarray): The wave vector.
r_min (float): The minimum value of the real space.
r_max (float): The maximum value of the real space.
n_points (int): The number of points in the real space.
basis_G (np.ndarray): The basis vector.
Returns:
np.ndarray: The wave function in real space.
"""
if basis_G is None:
basis_G = EPM_BASIS
eigen_values, eigen_vectors = eigen_states(k, basis_G)
# Get the wave function in real space
xyz = np.linspace(r_min, r_max, n_points)
X, Y, Z = np.meshgrid(xyz, xyz, xyz)
wave_function = np.zeros((n_points, n_points, n_points), dtype=complex)
for idx_x in range(n_points):
for idx_y in range(n_points):
for idx_z in range(n_points):
r_point = np.array([X[idx_x, idx_y, idx_z], Y[idx_x, idx_y, idx_z], Z[idx_x, idx_y, idx_z]])
for i in range(len(eigen_vectors)):
G_vect = basis_G[i]
exp_GdotR = np.exp(-1j * np.dot(basis_G, r_point))
print(f"Shape of exp_GdotR: {exp_GdotR.shape}")
SUM = np.dot(eigen_vectors, exp_GdotR)
print(f"Shape of SUM: {SUM.shape}")
for idx_G in range(len(G_vect)):
wave_function[idx_x, idx_y, idx_z] += np.exp(-1j * np.dot(G_vect, r_point)) * eigen_vectors[i, idx_G]
wave_function[idx_x, idx_y, idx_z] *= np.exp(1j * np.dot(k, r_point))
wave_function[idx_x, idx_y, idx_z] = np.abs(wave_function[idx_x, idx_y, idx_z])**2
return wave_function


def get_wave_function_real_space_vectorized(k: np.ndarray, r_min, r_max, n_points, basis_G: np.ndarray=None) -> np.ndarray:
"""Generate the wave function in real space.
Args:
k (np.ndarray): The wave vector.
r_min (float): The minimum value of the real space.
r_max (float): The maximum value of the real space.
n_points (int): The number of points in the real space.
basis_G (np.ndarray): The basis vector.
Returns:
np.ndarray: The wave function in real space.
"""
if basis_G is None:
basis_G = EPM_BASIS
eigen_values, eigen_vectors = eigen_states(k, basis_G)
print(f"Shape of eigen vectors: {eigen_vectors.shape}")
# Get the wave function in real space
xyz = np.linspace(r_min, r_max, n_points)
X, Y, Z = np.meshgrid(xyz, xyz, xyz)
wave_function = np.zeros((n_points, n_points, n_points), dtype=complex)
for i in range(len(eigen_vectors)):
G_vect = basis_G[i]
for idx_G in range(len(G_vect)):
wave_function += eigen_vectors[i, idx_G] * np.exp(-1j * np.dot(G_vect, [X, Y, Z]))
wave_function *= np.exp(1j * np.dot(k, [X, Y, Z]))
wave_function = np.abs(wave_function)**2



return wave_function.reshape(n_points, n_points, n_points)


def plot_wave_function_real_space(k: np.ndarray, r_min, r_max, n_points, basis_G: np.ndarray=None):
"""Plot the wave function in real space.
Args:
k (np.ndarray): The wave vector.
r_min (float): The minimum value of the real space.
r_max (float): The maximum value of the real space.
n_points (int): The number of points in the real space.
basis_G (np.ndarray): The basis vector.
"""
if basis_G is None:
basis_G = EPM_BASIS
wave_function = get_wave_function_real_space(k, r_min, r_max, n_points, basis_G)
print(f"Wave function in real space: {wave_function.shape}")
# Show the wave function in real space with heatmap, at different z values
ListIdxZ = [0, n_points // 4, n_points // 2, 3 * n_points // 4, n_points - 1]
fig, ax = plt.subplots(1, len(ListIdxZ), figsize=(20, 4))
for idx, idx_z in enumerate(ListIdxZ):
im = ax[idx].imshow(np.abs(wave_function[:, :, idx_z]), cmap='jet', interpolation='nearest', extent=[r_min, r_max, r_min, r_max])
ax[idx].set_title(f"z={idx_z}")
ax[idx].set_xlabel("x")
ax[idx].set_ylabel("y")
fig.colorbar(im, ax=ax.ravel().tolist())

plt.pause(4)


# Same with the fast version
wave_function = get_wave_function_real_space_vectorized(k, r_min, r_max, n_points, basis_G)
print(f"Wave function in real space: {wave_function.shape}")
# Show the wave function in real space with heatmap, at different z values
ListIdxZ = [0, n_points // 4, n_points // 2, 3 * n_points // 4, n_points - 1]
fig, ax = plt.subplots(1, len(ListIdxZ), figsize=(20, 4))
for idx, idx_z in enumerate(ListIdxZ):
im = ax[idx].imshow(np.abs(wave_function[:, :, idx_z]), cmap='jet', interpolation='nearest', extent=[r_min, r_max, r_min, r_max])
ax[idx].set_title(f"z={idx_z}")
ax[idx].set_xlabel("x")
ax[idx].set_ylabel("y")
fig.colorbar(im, ax=ax.ravel().tolist())
plt.pause(5)



def dielectric_function_mc(energy, q_vect, n_valence, n_conduction, Nk):
list_colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w', 'b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
Expand Down Expand Up @@ -315,15 +496,13 @@ def main_band_structure(n_points):
bands = band_structure(k)
bands -= max(bands[3])

plt.figure(figsize=(15, 9))

ax = plt.subplot(111)
fig, ax = plt.subplots(figsize=(10, 10))

# remove plot borders
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
# ax.spines['top'].set_visible(False)
# ax.spines['bottom'].set_visible(False)
# ax.spines['right'].set_visible(False)
# ax.spines['left'].set_visible(False)

# limit plot area to data
plt.xlim(0, len(bands))
Expand All @@ -334,36 +513,8 @@ def main_band_structure(n_points):
plt.xticks(xticks, ('$L$', '$\Lambda$', '$\Gamma$', '$\Delta$', '$X$', '$U,K$', '$\Sigma$', '$\Gamma$'), fontsize=18)
plt.yticks(fontsize=18)

# horizontal guide lines every 2.5 eV
for y in np.arange(-25, 25, 2.5):
plt.axhline(y, ls='--', lw=0.3, color='black', alpha=0.3)

# hide ticks, unnecessary with gridlines
plt.tick_params(axis='both', which='both',
top='off', bottom='off', left='off', right='off',
labelbottom='on', labelleft='on', pad=5)

plt.xlabel('k-Path', fontsize=20)
plt.ylabel('E(k) (eV)', fontsize=20)

plt.text(135, -18, 'Fig. 1. Band structure of Si.', fontsize=12)

# tableau 10 in fractional (r, g, b)
colors = 1 / 255 * np.array([
[31, 119, 180],
[255, 127, 14],
[44, 160, 44],
[214, 39, 40],
[148, 103, 189],
[140, 86, 75],
[227, 119, 194],
[127, 127, 127],
[188, 189, 34],
[23, 190, 207]
])

for band, color in zip(bands, colors):
plt.plot(band, lw=2.0, color=color)
for band in bands:
ax.plot(band, lw=2.0)

plt.show()

Expand All @@ -374,17 +525,13 @@ def main_band_structure(n_points):
n_valence = 4
n_conduction = 8
Nk = 1000
# main_band_structure(500)
# # # main_epsilon(N_k=Nk, energy=energy, q_vect=q_vect, n_valence=n_valence, n_conduction=n_conduction)
# # # convergence_Monkhorst_Pack(energy, q_vect, n_valence, n_conduction, 25)
# # # eps_mc = dielectric_function_mc(energy, q_vect, n_valence, n_conduction, Nk)
# # # print(f"epsilon_real MC = {eps_mc:.2e}")

# # # Nxyz = 40
# # # eps = dielectric_function_Monkhorst_Pack(energy, q_vect, n_valence, n_conduction, Nxyz)
# # # print(f"epsilon_real = {eps:.2e}")

# # # main_epsilon(Nxyz=60, q_vect=q_vect, n_valence=n_valence, n_conduction=n_conduction)
main_epsilon_mc(N_k=Nk, q_vect=q_vect, n_valence=n_valence, n_conduction=n_conduction)
# main_band_structure(250)
k_gamma = np.array([0.0, 0.0, 0.0])
rmin = -np.pi
rmax = np.pi
Nxyz = 20
# plot_eigen_states(k_gamma)
plot_wave_function_real_space(k_gamma, rmin, rmax, Nxyz)



Loading

0 comments on commit 9c5b75c

Please sign in to comment.