Skip to content

Commit

Permalink
New example app: visualize structure relaxation (#323)
Browse files Browse the repository at this point in the history
* add get_breadcrumb() doc str

* add crystal_toolkit/apps/examples/relaxation_trajectory.py

* add normed force mean over atoms line to relaxation plot

clean up code and styles
  • Loading branch information
janosh authored Mar 23, 2023
1 parent e10319d commit e5f91d2
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 20 deletions.
129 changes: 129 additions & 0 deletions crystal_toolkit/apps/examples/relaxation_trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester

import crystal_toolkit.components as ctc
from crystal_toolkit.settings import SETTINGS

mp_id = "mp-1033715"
with MPRester(monty_decode=False) as mpr:
[task_doc] = mpr.tasks.search(task_ids=[mp_id])

steps = [
(
Structure.from_dict(step["structure"]),
step["e_fr_energy"],
np.linalg.norm(step["forces"], axis=1).mean(),
)
for calc in reversed(task_doc.calcs_reversed)
for step in calc.output["ionic_steps"]
]
assert len(steps) == 99

e_col = "Energy (eV)"
force_col = "Force (eV/Å)"
spg_col = "Spacegroup"
struct_col = "Structure"

df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col])
df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info)


def plot_energy_and_forces(
df: pd.DataFrame,
step: int,
e_col: str,
force_col: str,
title: str,
) -> go.Figure:
"""Plot energy and forces as a function of relaxation step."""
fig = go.Figure()
# energy trace = primary y-axis
fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy"))

# forces trace = secondary y-axis
fig.add_trace(
go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
)

fig.update_layout(
template="plotly_white",
title=title,
xaxis={"title": "Relaxation Step"},
yaxis={"title": e_col},
yaxis2={"title": force_col, "overlaying": "y", "side": "right"},
legend=dict(yanchor="top", y=1, xanchor="right", x=1),
)

# vertical line at the specified step
fig.add_vline(x=step, line={"dash": "dash", "width": 1})

return fig


if "struct_comp" not in locals():
struct_comp = ctc.StructureMoleculeComponent(
id="structure", struct_or_mol=df_traj[struct_col][0]
)

step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps
slider = dcc.Slider(
id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag"
)


def make_title(spg: tuple[str, int]) -> str:
"""Return a title for the figure."""
href = f"https://materialsproject.org/materials/{mp_id}/"
return f"<a {href=}>{mp_id}</a> - {spg[0]} ({spg[1]})"


title = make_title(df_traj[spg_col][0])
graph = dcc.Graph(
id="fig",
figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title),
style={"maxWidth": "50%"},
)

app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)
app.layout = html.Div(
[
html.H1(
"Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
),
html.P("Drag slider to see structure at different relaxation steps."),
slider,
html.Div(
[struct_comp.layout(), graph],
style=dict(display="flex", gap="2em", placeContent="center"),
),
],
style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"),
)

ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
"""Update the structure displayed in the StructureMoleculeComponent and the
dashed vertical line in the figure when the slider is moved.
"""
title = make_title(df_traj[spg_col][step])
fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)

return df_traj[struct_col][step], fig


# https://stackoverflow.com/a/74918941
is_jupyter = "ipykernel" in sys.modules

app.run(port=8050, debug=True, use_reloader=not is_jupyter)
5 changes: 3 additions & 2 deletions crystal_toolkit/components/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ class StructureMoleculeComponent(MPComponent):

def __init__(
self,
struct_or_mol: None
| (Structure | StructureGraph | Molecule | MoleculeGraph) = None,
struct_or_mol: (
None | Structure | StructureGraph | Molecule | MoleculeGraph
) = None,
id: str = None,
className: str = "box",
scene_additions: Scene | None = None,
Expand Down
30 changes: 17 additions & 13 deletions crystal_toolkit/helpers/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,20 +439,24 @@ def __init__(self, *args, **kwargs) -> None:


def get_breadcrumb(parts):
"""Create a breadcrumb navigation bar.
Args:
parts (dict): Dictionary of name, link pairs.
Returns:
html.Nav: Breadcrumb navigation bar.
"""
if not parts:
return html.Div()
return html.Nav()

breadcrumbs = html.Nav(
html.Ul(
[
html.Li(
dcc.Link(name, href=link),
className=(None if idx != len(parts) - 1 else "is-active"),
)
for idx, (name, link) in enumerate(parts.items())
]
),
className="breadcrumb",
)
links = [
html.Li(
dcc.Link(name, href=link),
className="is-active" if idx == len(parts) - 1 else None,
)
for idx, (name, link) in enumerate(parts.items())
]
breadcrumbs = html.Nav(html.Ul(links), className="breadcrumb")

return breadcrumbs
4 changes: 2 additions & 2 deletions crystal_toolkit/helpers/povray_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@


def pov_write_data(input_scene_comp, fstream):
"""parse a primitive display object in crystaltoolkit and print it to POV-Ray input_scene_comp
fstream.
"""Parse a primitive display object in crystaltoolkit and print it to POV-Ray
input_scene_comp fstream.
"""
vect = "{:.4f},{:.4f},{:.4f}"

Expand Down
16 changes: 13 additions & 3 deletions crystal_toolkit/renderables/volumetric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Any

import numpy as np
from numpy.typing import ArrayLike
from pymatgen.io.vasp import VolumetricData

from crystal_toolkit.core.scene import Scene, Surface
Expand All @@ -9,15 +12,22 @@


def get_isosurface_scene(
self, data_key="total", isolvl=0.05, step_size=4, origin=None, **kwargs
):
self,
data_key: str = "total",
isolvl: float = 0.05,
step_size: int = 4,
origin: ArrayLike = None,
**kwargs: Any,
) -> Scene:
"""Get the isosurface from a VolumetricData object.
Args:
data_key (str, optional): Use the volumetric data from self.data[data_key]. Defaults to 'total'.
isolvl (float, optional): The cutoff for the isosurface to using the same units as VESTA so
e/bohr and kept grid size independent
e/bohr and kept grid size independent
step_size (int, optional): step_size parameter for marching_cubes_lewiner. Defaults to 3.
origin (ArrayLike, optional): The origin of the isosurface. Defaults to None.
**kwargs: Passed to the Surface object.
Returns:
Scene: object containing the isosurface component
Expand Down

0 comments on commit e5f91d2

Please sign in to comment.