Skip to content

Commit

Permalink
Excited state nodes, docs, and an example (lanl#42)
Browse files Browse the repository at this point in the history
* Add nodes for NACR and phase-less loss

* NACR node now does NACR_ij * ΔE_ij

* dE should be E_j - E_i

* remove parent expansion

* Update NACR implementation

1. Fix a bug where the charges tensor is incorrectly sliced.
2. Add corresponding tests to NACR layers.

* Fix a bug on calculating the phase-less loss

When the predicted vector is not in the same quadrant as the true value,
the calculated loss would be smaller than the correct one. This might
cause training process slower or stuck.

* Fix a bug in new MSE implementation

* Fix a bug in setting up custom kernels

Now "auto" will use pytorch when numba or cupy is not installed

* Add nodes and example for excited states

* Update changelog and rename example file

* update example for excited states

* move excited states into subdirectories

* fix import order and make excited states import by default

* update documentation

* documentation update

---------

Co-authored-by: Nicholas Lubbers <hippynn@lanl.gov>
  • Loading branch information
tautomer and Nicholas Lubbers authored Nov 15, 2023
1 parent 99aa5ae commit 2ff5e7a
Show file tree
Hide file tree
Showing 14 changed files with 622 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
__pycache__/
*.pyc
build/
hippynn.egg-info/*
hippynn.egg-info/*
14 changes: 14 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
0.0.2a3
=======

New Features:
-------------

- Add nodes for non-adiabatic coupling vectors (NACR) and phase-less loss.
See /examples/excited_states_azomethane.py.

Improvements
------------

- Multi-target dipole node now has a shape of (n_molecules, n_targets, 3).

0.0.2a2
=======

Expand Down
72 changes: 72 additions & 0 deletions docs/source/examples/excited_states.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
Non-Adiabiatic Excited States
=============================

`hippynn` has features for training to excited-state energies, transition dipoles, and
the non-adiabatic coupling vectors (NACR). These features can be found in
:mod:`~hippynn.graphs.nodes.excited`.

For a more detailed description, please see the paper [Li2023]_

Multi-targets nodes are recommended over the usage of one node per target.

For energies, the node can be constructed just like the ground-state
counterpart::

energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
mol_energy = energy.mol_energy
mol_energy.db_name = "E"

Note that a `multi-target node` is used here, defined by the keyword
``module_kwargs={"n_target": n_states + 1}``. Here, `n_states` is the number of
*excited* states in consideration. The extra state is for the ground state, which is often
useful. The database name is simply `E` with a shape of ``(n_molecules,
n_states+1)``.

Predicting the transition dipoles is also similar to the ground-state permanent
dipole::

charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
dipole = physics.DipoleNode("D", (charge, positions), db_name="D")

The database name is `D` with a shape of ``(n_molecules, n_states, 3)``.

For NACR, to avoid singularity problems, we enforcing the training of NACR*ΔE
instead::

nacr = excited.NACRMultiStateNode(
"ScaledNACR",
(charge, positions, energy),
db_name="ScaledNACR",
module_kwargs={"n_target": n_states},
)

For NACR between state `i` and `j`, :math:`\boldsymbol{d}_{ij}`, it is expressed
in the following way

.. math::
\boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}}
:math:`E_{ij}` is energy difference between state `i` and `j`, which is
calculated internally in the NACR node based on the input of the ``energy``
node. :math:`\boldsymbol{R}` corresponding the ``positions`` node in the code.
:math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the transition
atomic charges for state `i` and `j` contained in the ``charge`` node. This
charge node can be constructed from scratch or reused from the dipole
predictions. The database name is `ScaledNACR` with a shape of ``(n_molecules,
n_states*(n_states-1)/2, 3*n_atoms)``.

Due to the phase problem, when the loss function is constructed, the
`phase-less` version of MAE or RMSE should be used::

energy_mae = loss.MAELoss.of_node(energy)
dipole_mae = excited.MAEPhaseLoss.of_node(dipole)
nacr_mae = excited.MAEPhaseLoss.of_node(nacr)

:class:`~hippynn.graphs.nodes.excited.MAEPhaseLoss` and
:class:`~hippynn.graphs.nodes.excited.MSEPhaseLoss` are the `phase-less` version MAE
and MSE, which take the minimum error over the possible signs of the output.

For a complete script, please take a look at ``examples/excited_states_azomethane.py``.

.. [Li2023] | Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials.
| Li et. al, 2023. https://arxiv.org/abs/2306.02523
1 change: 1 addition & 0 deletions docs/source/examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ the examples are just snippets. For fully-fledged examples see the
restarting
ase_calculator
mliap_unified
excited_states

188 changes: 188 additions & 0 deletions examples/excited_states_azomethane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""
Example training script to predicted excited-states energies, transition dipoles, and
non-adiabatic coupling vectors (NACR)
The dataset used in this example can be found at https://doi.org/10.5281/zenodo.7076420.
This script is set up to assume the "release" folder from the zenodo record
is placed in ../../datasets/azomethane/ relative to this script.
For more information on the modeling techniques, please see the paper:
Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials
Li, et al. (2023)
https://arxiv.org/abs/2306.02523
"""
import json

import matplotlib
import numpy as np
import torch

import hippynn
from hippynn import plotting
from hippynn.experiment import setup_training, train_model
from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau
from hippynn.graphs import inputs, loss, networks, physics, targets, excited

matplotlib.use("Agg")
# default types for torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_default_dtype(torch.float32)

hippynn.settings.WARN_LOW_DISTANCES = False
hippynn.settings.TRANSPARENT_PLOT = True

n_atoms = 10
n_states = 3
plot_frequency = 100
dipole_weight = 4
nacr_weight = 2
l2_weight = 2e-5

# Hyperparameters for the network
# Note: These hyperparameters were generated via
# a tuning algorithm, hence their somewhat arbitrary nature.
network_params = {
"possible_species": [0, 1, 6, 7],
"n_features": 30,
"n_sensitivities": 28,
"dist_soft_min": 0.7665723566179274,
"dist_soft_max": 3.4134447177301515,
"dist_hard_max": 4.6860240434651805,
"n_interaction_layers": 3,
"n_atom_layers": 3,
}
# dump parameters to the log file
print("Network parameters\n\n", json.dumps(network_params, indent=4))

with hippynn.tools.active_directory("TEST_AZOMETHANE_MODEL"):
with hippynn.tools.log_terminal("training_log.txt", "wt"):
# build network
species = inputs.SpeciesNode(db_name="Z")
positions = inputs.PositionsNode(db_name="R")
network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params)
# add energy
energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
mol_energy = energy.mol_energy
mol_energy.db_name = "E"
# add dipole
charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
dipole = physics.DipoleNode("D", (charge, positions), db_name="D")
# add NACR
nacr = excited.NACRMultiStateNode(
"ScaledNACR",
(charge, positions, energy),
db_name="ScaledNACR",
module_kwargs={"n_target": n_states},
)
# set up plotter
plotter = []
for node in [mol_energy, dipole, nacr]:
plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False))
for i in range(network_params["n_interaction_layers"]):
plotter.append(
plotting.SensitivityPlot(
network.torch_module.sensitivity_layers[i],
saved=f"Sensitivity_{i}.pdf",
shown=False,
)
)
plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency)
# build the loss function
validation_losses = {}
# energy
energy_rmse = loss.MSELoss.of_node(energy) ** 0.5
validation_losses["E-RMSE"] = energy_rmse
energy_mae = loss.MAELoss.of_node(energy)
validation_losses["E-MAE"] = energy_mae
energy_loss = energy_rmse + energy_mae
validation_losses["E-Loss"] = energy_loss
total_loss = energy_loss
# dipole
dipole_rmse = excited.MSEPhaseLoss.of_node(dipole) ** 0.5
validation_losses["D-RMSE"] = dipole_rmse
dipole_mae = excited.MAEPhaseLoss.of_node(dipole)
validation_losses["D-MAE"] = dipole_mae
dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae
validation_losses["D-Loss"] = dipole_loss
total_loss += dipole_weight * dipole_loss
# nacr
nacr_rmse = excited.MSEPhaseLoss.of_node(nacr) ** 0.5
validation_losses["NACR-RMSE"] = nacr_rmse
nacr_mae = excited.MAEPhaseLoss.of_node(nacr)
validation_losses["NACR-MAE"] = nacr_mae
nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae
validation_losses["NACR-Loss"] = nacr_loss
total_loss += nacr_weight * nacr_loss
# l2 regularization
l2_reg = loss.l2reg(network)
validation_losses["L2"] = l2_reg
loss_regularization = l2_weight * l2_reg
# add total loss to the dictionary
validation_losses["Loss_wo_L2"] = total_loss
validation_losses["Loss"] = total_loss + loss_regularization

# set up experiment
training_modules, db_info = hippynn.experiment.assemble_for_training(
validation_losses["Loss"],
validation_losses,
plot_maker=plotter,
)
# set up the optimizer
optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3)
# use higher patience for production runs
scheduler = RaiseBatchSizeOnPlateau(optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5)
controller = PatienceController(
optimizer=optimizer,
scheduler=scheduler,
batch_size=32,
eval_batch_size=2048,
# use higher max_epochs for production runs
max_epochs=100,
stopping_key="Loss",
fraction_train_eval=0.1,
# use higher termination_patience for production runs
termination_patience=10,
)
experiment_params = hippynn.experiment.SetupParams(controller=controller)

# load database
database = hippynn.databases.DirectoryDatabase(
name="azo_", # Prefix for arrays in the directory
directory="../../../datasets/azomethane/release/training/",
seed=114514, # Random seed for splitting data
**db_info, # Adds the inputs and targets db_names from the model as things to load
)
# use 10% of the dataset just for quick testing purpose
database.make_random_split("train", 0.07)
database.make_random_split("valid", 0.02)
database.make_random_split("test", 0.01)
database.splitting_completed = True
# split the whole dataset into train, valid, test in the ratio of 7:2:1
# database.make_trainvalidtest_split(0.1, 0.2)

# set up training
training_modules, controller, metric_tracker = setup_training(
training_modules=training_modules,
setup_params=experiment_params,
)
# train model
metric_tracker = train_model(
training_modules,
database,
controller,
metric_tracker,
callbacks=None,
batch_callbacks=None,
)

del network_params["possible_species"]
network_params["metric"] = metric_tracker.best_metric_values
network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times)
network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"]

with open("training_summary.json", "w") as out:
json.dump(network_params, out, indent=4)
4 changes: 3 additions & 1 deletion hippynn/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from . import indextypes
from .indextypes import clear_index_cache, IdxType

from .nodes import base, inputs, networks, targets, physics, loss
from .nodes import base, inputs
from .nodes.base import find_unique_relative, find_relatives, get_connected_nodes

from .gops import get_subgraph, copy_subgraph, replace_node, compute_evaluation_order

from .nodes import networks, targets, physics, loss, excited

# Needed to populate the registry of index transformers.
# This has to happen before the indextypes package can work,
# however, we don't want the indextypes package to depend on actual
Expand Down
Loading

0 comments on commit 2ff5e7a

Please sign in to comment.