Skip to content

Commit

Permalink
Add style_file option, return figures, cleanup etc
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Aug 23, 2024
1 parent f8296d8 commit 534d738
Show file tree
Hide file tree
Showing 2 changed files with 11,250 additions and 11,293 deletions.
256 changes: 139 additions & 117 deletions doped/utils/displacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def plot_site_displacements(
Whether to use ``plotly`` for plotting. Default is ``False``.
Set to ``True`` to get an interactive plot.
style_file (PathLike):
Path to matplotlib style file. if not set, will use the
``doped`` default style.
Path to ``matplotlib`` style file. if not set, will use the
``doped`` default displacements style.
Returns:
``plotly`` or ``matplotlib`` ``Figure``.
Expand Down Expand Up @@ -495,33 +495,46 @@ def calc_displacements_ellipsoid(
defect_entry: DefectEntry,
plot_ellipsoid: bool = False,
plot_anisotropy: bool = False,
use_plotly: bool = False,
quantile=0.8,
):
use_plotly: bool = False,
show_supercell: bool = True,
style_file: Optional[PathLike] = None,
) -> tuple:
"""
Calculate displacements around a defect site and fit an ellipsoid to these
displacements.
Set ``use_plotly = True`` to get an interactive ``plotly``
plot, useful for analysis!
The supercell edges are also plotted if ``show_supercell = True``
(default).
Args:
defect_entry (DefectEntry): ``DefectEntry`` object.
plot_ellipsoid (bool):
If True, plot the fitted ellipsoid in the crystal lattice.
plot_anisotropy (bool):
If True, plot the anisotropy of the ellipsoid radii.
use_plotly (bool):
Whether to use ``plotly`` for plotting. Default is ``False``.
Set to ``True`` to get an interactive plot.
quantile (float):
The quantile threshold for selecting significant displacements
(between 0 and 1). Default is 0.8.
use_plotly (bool):
Whether to use ``plotly`` for plotting. Default is ``False``.
Set to ``True`` to get an interactive plot.
show_supercell (bool):
Whether to show the supercell edges in the plot. Default is
``True``.
style_file (PathLike):
Path to ``matplotlib`` style file. if not set, will use the
``doped`` default displacements style.
Returns:
- (ellipsoid_center, ellipsoid_radii, ellipsoid_rotation):
A tuple containing the ellipsoid's center, radii, and rotation matrix,
or ``(None, None, None)`` if fitting was unsuccessful.
- ``plotly`` or ``matplotlib`` ``Figure``, if ``plot_ellipsoid = True``.
- ``plotly`` or ``matplotlib`` ``Figure``, if ``plot_anisotropy = True``.
"""
if use_plotly and not plotly_installed:
warnings.warn("Plotly not installed, using matplotlib instead")
Expand Down Expand Up @@ -587,7 +600,9 @@ def _get_minimum_volume_ellipsoid(P):

return (center, radii, rotation)

def _mpl_plot_ellipsoid(ellipsoid_center, ellipsoid_radii, ellipsoid_rotation, points, lattice_matrix):
def _mpl_plot_ellipsoid(
ellipsoid_center, ellipsoid_radii, ellipsoid_rotation, points, lattice_matrix, style_file
):
u = np.linspace(0.0, 2.0 * np.pi, 100)
v = np.linspace(0.0, np.pi, 100)

Expand All @@ -604,92 +619,95 @@ def _mpl_plot_ellipsoid(ellipsoid_center, ellipsoid_radii, ellipsoid_rotation, p
)

# Create a 3D plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

# Plot the ellipsoid surface
ax.plot_surface(x, y, z, color="blue", alpha=0.2, rstride=4, cstride=4)
style_file = style_file or f"{os.path.dirname(__file__)}/displacement.mplstyle"
plt.style.use(style_file) # enforce style, as style.context currently doesn't work with jupyter
with plt.style.context(style_file):
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

# Plot the points
ax.scatter(points[:, 0], points[:, 1], points[:, 2], color="black", s=10)
# Plot the ellipsoid surface
ax.plot_surface(x, y, z, color="blue", alpha=0.2, rstride=4, cstride=4)

# Plot the ellipsoid axes
axes = np.array(
[
[ellipsoid_radii[0], 0.0, 0.0],
[0.0, ellipsoid_radii[1], 0.0],
[0.0, 0.0, ellipsoid_radii[2]],
]
)
for i in range(len(axes)):
axes[i] = np.dot(axes[i], ellipsoid_rotation)
# Plot the points
ax.scatter(points[:, 0], points[:, 1], points[:, 2], color="black", s=10)

for p in axes:
ax.plot(
[ellipsoid_center[0], ellipsoid_center[0] + p[0]],
[ellipsoid_center[1], ellipsoid_center[1] + p[1]],
[ellipsoid_center[2], ellipsoid_center[2] + p[2]],
color="black",
linewidth=2,
# Plot the ellipsoid axes
axes = np.array(
[
[ellipsoid_radii[0], 0.0, 0.0],
[0.0, ellipsoid_radii[1], 0.0],
[0.0, 0.0, ellipsoid_radii[2]],
]
)
for i in range(len(axes)):
axes[i] = np.dot(axes[i], ellipsoid_rotation)

for p in axes:
ax.plot(
[ellipsoid_center[0], ellipsoid_center[0] + p[0]],
[ellipsoid_center[1], ellipsoid_center[1] + p[1]],
[ellipsoid_center[2], ellipsoid_center[2] + p[2]],
color="black",
linewidth=2,
)

def _plot_lattice(lattice_matrix, ax):

# Scale factor for the lattice lines
scale = 0.1

# Create lines along each lattice vector
for i in range(3):
x = [lattice_matrix[i][0] * scale * n for n in range(11)]
y = [lattice_matrix[i][1] * scale * n for n in range(11)]
z = [lattice_matrix[i][2] * scale * n for n in range(11)]
ax.plot(x, y, z, color="black", linewidth=0.5)

# Create lines for combinations of lattice vectors
for j in range(3):
if i != j:
x_comb = [
lattice_matrix[i][0] * scale * n + lattice_matrix[j][0] for n in range(11)
]
y_comb = [
lattice_matrix[i][1] * scale * n + lattice_matrix[j][1] for n in range(11)
]
z_comb = [
lattice_matrix[i][2] * scale * n + lattice_matrix[j][2] for n in range(11)
]
ax.plot(x_comb, y_comb, z_comb, color="black", linewidth=0.5)

for k in range(3):
if i != k and j != k:
x_comb3 = [
lattice_matrix[i][0] * scale * n
+ lattice_matrix[j][0]
+ lattice_matrix[k][0]
for n in range(11)
]
y_comb3 = [
lattice_matrix[i][1] * scale * n
+ lattice_matrix[j][1]
+ lattice_matrix[k][1]
for n in range(11)
]
z_comb3 = [
lattice_matrix[i][2] * scale * n
+ lattice_matrix[j][2]
+ lattice_matrix[k][2]
for n in range(11)
]
ax.plot(x_comb3, y_comb3, z_comb3, color="black", linewidth=0.5)

_plot_lattice(lattice_matrix, ax)

# Set the aspect ratio and limits
ax.set_box_aspect([1, 1, 1])
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")

plt.show()
def _plot_lattice(lattice_matrix, ax):

# Scale factor for the lattice lines
scale = 0.1

# Create lines along each lattice vector
for i in range(3):
x = [lattice_matrix[i][0] * scale * n for n in range(11)]
y = [lattice_matrix[i][1] * scale * n for n in range(11)]
z = [lattice_matrix[i][2] * scale * n for n in range(11)]
ax.plot(x, y, z, color="black")

# Create lines for combinations of lattice vectors
for j in range(3):
if i != j:
x_comb = [
lattice_matrix[i][0] * scale * n + lattice_matrix[j][0] for n in range(11)
]
y_comb = [
lattice_matrix[i][1] * scale * n + lattice_matrix[j][1] for n in range(11)
]
z_comb = [
lattice_matrix[i][2] * scale * n + lattice_matrix[j][2] for n in range(11)
]
ax.plot(x_comb, y_comb, z_comb, color="black")

for k in range(3):
if i != k and j != k:
x_comb3 = [
lattice_matrix[i][0] * scale * n
+ lattice_matrix[j][0]
+ lattice_matrix[k][0]
for n in range(11)
]
y_comb3 = [
lattice_matrix[i][1] * scale * n
+ lattice_matrix[j][1]
+ lattice_matrix[k][1]
for n in range(11)
]
z_comb3 = [
lattice_matrix[i][2] * scale * n
+ lattice_matrix[j][2]
+ lattice_matrix[k][2]
for n in range(11)
]
ax.plot(x_comb3, y_comb3, z_comb3, color="black")

if show_supercell:
_plot_lattice(lattice_matrix, ax)

ax.set_box_aspect([1, 1, 1], zoom=0.9) # set the aspect ratio and limits, and zoom out a bit
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")

return fig

def _plotly_plot_ellipsoid(
ellipsoid_center, ellipsoid_radii, ellipsoid_rotation, points, lattice_matrix
Expand Down Expand Up @@ -804,8 +822,10 @@ def _plot_lattice(lattice_matrix, fig):
)
)

_plot_lattice(lattice_matrix, fig)
fig.show()
if show_supercell:
_plot_lattice(lattice_matrix, fig)

return fig

def _mpl_plot_anisotropy(ellipsoid_radii, disp_df, threshold):
fig, axs = plt.subplots(1, 2, figsize=(14, 6))
Expand Down Expand Up @@ -848,9 +868,8 @@ def _mpl_plot_anisotropy(ellipsoid_radii, disp_df, threshold):
axs[1].set_ylim([0, 1])
axs[1].grid(False)

# Adjust layout and show plot
plt.tight_layout()
plt.show()
fig.tight_layout() # adjust layout
return fig

def _plotly_plot_anisotropy(ellipsoid_radii, disp_df, threshold):
fig = make_subplots(
Expand Down Expand Up @@ -954,8 +973,7 @@ def _plotly_plot_anisotropy(ellipsoid_radii, disp_df, threshold):
showlegend=False, # Disable legend
)

# Show the combined plot
fig.show()
return fig

def _shift_defect_site_to_center_of_the_supercell(sites_frac_coords, defect_frac_coords, bulk_sc):
"""
Expand Down Expand Up @@ -1102,6 +1120,10 @@ def _shift_defect_site_to_center_of_the_supercell(sites_frac_coords, defect_frac
]
].to_numpy()

ellipsoid_center = None
ellipsoid_radii = None
ellipsoid_rotation = None

# Only proceed if there are at least 10 points over the threshold
if points.shape[0] >= 10:
try:
Expand All @@ -1112,38 +1134,38 @@ def _shift_defect_site_to_center_of_the_supercell(sites_frac_coords, defect_frac
if plot_ellipsoid:
lattice_matrix = bulk_sc.as_dict()["lattice"]["matrix"]
if use_plotly:
_plotly_plot_ellipsoid(
ellipsoid_fig = _plotly_plot_ellipsoid(
ellipsoid_center, ellipsoid_radii, ellipsoid_rotation, points, lattice_matrix
)
else:
_mpl_plot_ellipsoid(
ellipsoid_center, ellipsoid_radii, ellipsoid_rotation, points, lattice_matrix
ellipsoid_fig = _mpl_plot_ellipsoid(
ellipsoid_center,
ellipsoid_radii,
ellipsoid_rotation,
points,
lattice_matrix,
style_file,
)

# If anisotropy plotting is enabled, plot the ellipsoid's radii anisotropy
if plot_anisotropy:
if use_plotly:
_plotly_plot_anisotropy(ellipsoid_radii, disp_df, threshold)
anisotropy_fig = _plotly_plot_anisotropy(ellipsoid_radii, disp_df, threshold)
else:
_mpl_plot_anisotropy(ellipsoid_radii, disp_df, threshold)

# Return the ellipsoid's center, radii, and rotation matrix
return (ellipsoid_center, ellipsoid_radii, ellipsoid_rotation)
anisotropy_fig = _mpl_plot_anisotropy(ellipsoid_radii, disp_df, threshold)

except np.linalg.LinAlgError:
# Handle the case where the matrix is singular and fitting fails
except np.linalg.LinAlgError: # handle the case where the matrix is singular and fitting fails
print("The matrix is singular and the system has no unique solution.")
ellipsoid_center = None
ellipsoid_radii = None
ellipsoid_rotation = None
return (ellipsoid_center, ellipsoid_radii, ellipsoid_rotation)
else:
# If there aren't enough points, suggest using a smaller quantile and return None values
print("Use smaller quantile.")
ellipsoid_center = None
ellipsoid_radii = None
ellipsoid_rotation = None
return (ellipsoid_center, ellipsoid_radii, ellipsoid_rotation)
else: # If there aren't enough points, suggest using a smaller quantile and return None values
print("Not enough points for plotting, try using a smaller quantile!")

return_tuple: tuple = (ellipsoid_center, ellipsoid_radii, ellipsoid_rotation)
if plot_ellipsoid:
return_tuple += (ellipsoid_fig,)
if plot_anisotropy:
return_tuple += (anisotropy_fig,)

return return_tuple


def _get_bulk_struct_with_defect(defect_entry) -> tuple:
Expand Down
Loading

0 comments on commit 534d738

Please sign in to comment.