Skip to content

Commit

Permalink
grad overlap matrix smile
Browse files Browse the repository at this point in the history
  • Loading branch information
austingmhuang committed Sep 16, 2024
1 parent a3dbda2 commit 62c97dc
Showing 1 changed file with 72 additions and 1 deletion.
73 changes: 72 additions & 1 deletion tests/qchem/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_gradient_overlap_matrix(
r"""Test that the overlap gradients are correct."""
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
args = [mol.alpha, mol.coeff]
g_alpha = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[0])(*args)
g_alpha = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[0, 1])(*args)
g_coeff = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[1])(*args)
assert np.allclose(g_alpha, g_alpha_ref)
assert np.allclose(g_coeff, g_coeff_ref)
Expand Down Expand Up @@ -693,6 +693,77 @@ def test_moment_matrix_jax(self):
s = qchem.moment_matrix(mol.basis_set, e, idx)(*args)
assert np.allclose(s, s_ref)

@pytest.mark.parametrize(
("symbols", "geometry", "alpha", "coeff", "g_alpha_ref", "g_coeff_ref"),
[
(
["H", "H"],
np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], requires_grad=False),
np.array(
[[3.42525091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]],
requires_grad=True,
),
np.array(
[[0.15432897, 0.53532814, 0.44463454], [0.15432897, 0.53532814, 0.44463454]],
requires_grad=True,
),
# Jacobian matrix contains gradient of S11, S12, S21, S22 wrt arg_1, arg_2.
np.array(
[
[
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
[
[-0.00043783, -0.09917143, -0.11600206],
[-0.00043783, -0.09917143, -0.11600206],
],
],
[
[
[-0.00043783, -0.09917143, -0.11600206],
[-0.00043783, -0.09917143, -0.11600206],
],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
],
]
),
np.array(
[
[
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
[
[-0.15627637, -0.02812029, 0.08809831],
[-0.15627637, -0.02812029, 0.08809831],
],
],
[
[
[-0.15627637, -0.02812029, 0.08809831],
[-0.15627637, -0.02812029, 0.08809831],
],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
],
]
),
)
],
)
def test_gradient_overlap_matrix_jax(
self, symbols, geometry, alpha, coeff, g_alpha_ref, g_coeff_ref
):
r"""Test that the overlap gradients are correct with jax."""
import jax
jax.config.update("jax_enable_x64", True)

geometry = qml.math.array(geometry, like="jax")
alpha = qml.math.array(alpha, like="jax")
coeff = qml.math.array(coeff, like="jax")
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
args = [mol.coordinates, mol.coeff, mol.alpha]
g_alpha, g_coeff = jax.jacobian(qchem.overlap_matrix(mol.basis_set), argnums=[2, 1])(*args)

assert np.allclose(g_alpha, g_alpha_ref)
assert np.allclose(g_coeff, g_coeff_ref)

def test_kinetic_matrix_jax(self):
r"""Test that kinetic_matrix returns the correct matrix when using jax."""
symbols, geometry, alpha = generate_symbols_geometry_alpha()
Expand Down

0 comments on commit 62c97dc

Please sign in to comment.