Skip to content

Commit

Permalink
complete test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
austingmhuang committed Sep 17, 2024
1 parent f1c57ce commit 450a147
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 19 deletions.
6 changes: 3 additions & 3 deletions pennylane/qchem/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def _moment_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -712,7 +712,7 @@ def _kinetic_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -936,7 +936,7 @@ def _attraction_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down
1 change: 0 additions & 1 deletion pennylane/qchem/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def kinetic(*args):
if args:
args_ab.extend([arg[i], arg[j]] for arg in args)
integral = kinetic_integral(a, b, normalize=False)(*args_ab)

o = qml.math.zeros((n, n))
o[i, j] = o[j, i] = 1.0
matrix = matrix + integral * o
Expand Down
3 changes: 2 additions & 1 deletion pennylane/qchem/observable_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def fermionic_observable(constant, one=None, two=None, cutoff=1.0e-12):
+ 0.5 * a⁺(3) a(3)
"""
coeffs = qml.math.array([])
constant = qml.math.array([constant])
if isinstance(constant, float):
constant = qml.math.array([constant])

if not qml.math.allclose(constant, 0.0):
coeffs = qml.math.concatenate((coeffs, constant))
Expand Down
2 changes: 1 addition & 1 deletion tests/qchem/test_hartree_fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def test_nuclear_energy_gradient_jax(self, symbols, geometry, g_ref):
),
],
)
def test_hf_energy_gradient(self, symbols, geometry, g_ref):
def test_hf_energy_gradient_jax(self, symbols, geometry, g_ref):
r"""Test that the gradient of the Hartree-Fock energy wrt differentiable parameters is
correct with jax."""
import jax
Expand Down
228 changes: 215 additions & 13 deletions tests/qchem/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,19 +680,6 @@ def test_overlap_matrix_jax(self):
s = qchem.overlap_matrix(mol.basis_set)(*args)
assert np.allclose(s, s_ref)

def test_moment_matrix_jax(self):
r"""Test that moment_matrix returns the correct matrix when using jax."""
symbols, _, alpha = generate_symbols_geometry_alpha()
geometry = qml.math.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], like="jax")
e = 1
idx = 0
s_ref = np.array([[0.0, 0.4627777], [0.4627777, 2.0]])

mol = qchem.Molecule(symbols, geometry, alpha=alpha)
args = [geometry, mol.coeff, alpha]
s = qchem.moment_matrix(mol.basis_set, e, idx)(*args)
assert np.allclose(s, s_ref)

@pytest.mark.parametrize(
("symbols", "geometry", "alpha", "coeff", "g_alpha_ref", "g_coeff_ref"),
[
Expand Down Expand Up @@ -765,6 +752,95 @@ def test_gradient_overlap_matrix_jax(
assert np.allclose(g_alpha, g_alpha_ref)
assert np.allclose(g_coeff, g_coeff_ref)

def test_moment_matrix_jax(self):
r"""Test that moment_matrix returns the correct matrix when using jax."""
symbols, _, alpha = generate_symbols_geometry_alpha()
geometry = qml.math.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], like="jax")
e = 1
idx = 0
s_ref = np.array([[0.0, 0.4627777], [0.4627777, 2.0]])

mol = qchem.Molecule(symbols, geometry, alpha=alpha)
args = [geometry, mol.coeff, alpha]
s = qchem.moment_matrix(mol.basis_set, e, idx)(*args)
assert np.allclose(s, s_ref)

@pytest.mark.parametrize(
("symbols", "geometry", "alpha", "coeff", "e", "idx", "g_alpha_ref", "g_coeff_ref"),
[
(
["H", "H"],
np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], requires_grad=False),
np.array(
[[3.42525091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]],
requires_grad=True,
),
np.array(
[[0.15432897, 0.53532814, 0.44463454], [0.15432897, 0.53532814, 0.44463454]],
requires_grad=True,
),
1,
0,
# Jacobian matrix contains gradient of S11, S12, S21, S22 wrt arg_1, arg_2, computed
# with finite difference.
np.array(
[
[
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
[
[3.87296664e-03, -2.29246093e-01, -9.93852751e-01],
[-4.86326933e-04, -6.72924734e-02, 2.47919030e-01],
],
],
[
[
[3.87296664e-03, -2.29246093e-01, -9.93852751e-01],
[-4.86326933e-04, -6.72924734e-02, 2.47919030e-01],
],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
],
]
),
np.array(
[
[
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
[
[-0.26160753, -0.18843804, 0.3176762],
[-0.09003791, 0.01797702, 0.00960757],
],
],
[
[
[-0.26160753, -0.18843804, 0.3176762],
[-0.09003791, 0.01797702, 0.00960757],
],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
],
]
),
)
],
)
def test_gradient_moment_matrix_jax(
self, symbols, geometry, alpha, coeff, e, idx, g_alpha_ref, g_coeff_ref
):
r"""Test that the moment matrix gradients are correct for jax."""
import jax

jax.config.update("jax_enable_x64", True)

geometry = qml.math.array(geometry, like="jax")
alpha = qml.math.array(alpha, like="jax")
coeff = qml.math.array(coeff, like="jax")
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
args = [mol.coordinates, mol.coeff, mol.alpha]
g_coeff, g_alpha = jax.jacobian(qchem.moment_matrix(mol.basis_set, e, idx), argnums=[1, 2])(
*args
)
assert np.allclose(g_alpha, g_alpha_ref)
assert np.allclose(g_coeff, g_coeff_ref)

def test_kinetic_matrix_jax(self):
r"""Test that kinetic_matrix returns the correct matrix when using jax."""
symbols, geometry, alpha = generate_symbols_geometry_alpha()
Expand All @@ -780,6 +856,77 @@ def test_kinetic_matrix_jax(self):
t = qchem.kinetic_matrix(mol.basis_set)(*args)
assert np.allclose(t, t_ref)

@pytest.mark.parametrize(
("symbols", "geometry", "alpha", "coeff", "g_alpha_ref", "g_coeff_ref"),
[
(
["H", "H"],
np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], requires_grad=False),
np.array(
[[3.42525091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]],
requires_grad=True,
),
np.array(
[[0.15432897, 0.53532814, 0.44463454], [0.15432897, 0.53532814, 0.44463454]],
requires_grad=True,
),
# Jacobian matrix contains gradient of T11, T12, T21, T22 wrt arg_1, arg_2.
np.array(
[
[
[[0.03263157, 0.85287851, 0.68779528], [0.0, 0.0, 0.0]],
[
[-0.00502729, 0.08211579, 0.3090185],
[-0.00502729, 0.08211579, 0.3090185],
],
],
[
[
[-0.00502729, 0.08211579, 0.3090185],
[-0.00502729, 0.08211579, 0.3090185],
],
[[0.0, 0.0, 0.0], [0.03263157, 0.85287851, 0.68779528]],
],
]
),
np.array(
[
[
[[1.824217, 0.10606991, -0.76087597], [0.0, 0.0, 0.0]],
[
[-0.00846016, 0.08488012, -0.09925695],
[-0.00846016, 0.08488012, -0.09925695],
],
],
[
[
[-0.00846016, 0.08488012, -0.09925695],
[-0.00846016, 0.08488012, -0.09925695],
],
[[0.0, 0.0, 0.0], [1.824217, 0.10606991, -0.76087597]],
],
]
),
)
],
)
def test_gradient_kinetic_matrix_jax(
self, symbols, geometry, alpha, coeff, g_alpha_ref, g_coeff_ref
):
r"""Test that the kinetic gradients are correct for jax."""
import jax

jax.config.update("jax_enable_x64", True)

geometry = qml.math.array(geometry, like="jax")
alpha = qml.math.array(alpha, like="jax")
coeff = qml.math.array(coeff, like="jax")
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
args = [geometry, mol.coeff, mol.alpha]
g_alpha, g_coeff = jax.jacobian(qchem.kinetic_matrix(mol.basis_set), argnums=[2, 1])(*args)
assert np.allclose(g_alpha, g_alpha_ref)
assert np.allclose(g_coeff, g_coeff_ref)

def test_core_matrix_diff_positions_jax(self):
r"""Test that core_matrix returns the correct matrix when positions are differentiable
when using jax."""
Expand Down Expand Up @@ -831,3 +978,58 @@ def test_attraction_matrix_diffR_jax(self):
args = [geometry, mol.coeff, alpha]
v = qchem.attraction_matrix(mol.basis_set, mol.nuclear_charges, mol.coordinates)(*args)
assert np.allclose(v, v_ref)

@pytest.mark.parametrize(
("symbols", "geometry", "alpha", "coeff", "g_r_ref"),
[
(
["H", "H"],
np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], requires_grad=True),
np.array(
[[3.42525091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]],
requires_grad=True,
),
np.array(
[[0.15432897, 0.53532814, 0.44463454], [0.15432897, 0.53532814, 0.44463454]],
requires_grad=True,
),
np.array(
[
[
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.44900112]],
[
[0.0, 0.0, -0.26468668],
[0.0, 0.0, 0.26468668],
],
],
[
[
[0.0, 0.0, -0.26468668],
[0.0, 0.0, 0.26468668],
],
[[0.0, 0.0, -0.44900112], [0.0, 0.0, 0.0]],
],
]
),
)
],
)
def test_gradient_attraction_matrix_jax(self, symbols, geometry, alpha, coeff, g_r_ref):
r"""Test that the attraction gradients are correct for jax."""
import jax

jax.config.update("jax_enable_x64", True)

geometry = qml.math.array(geometry, like="jax")
alpha = qml.math.array(alpha, like="jax")
coeff = qml.math.array(coeff, like="jax")
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
r_basis = mol.coordinates
args = [mol.coordinates, r_basis, mol.coeff, mol.alpha]

g_r, _, _, _ = jax.jacobian(
qchem.attraction_matrix(mol.basis_set, mol.nuclear_charges, mol.coordinates),
argnums=[0, 1, 2, 3],
)(*args)

assert np.allclose(g_r, g_r_ref)

0 comments on commit 450a147

Please sign in to comment.