Skip to content

Commit

Permalink
PASSING TESTS
Browse files Browse the repository at this point in the history
  • Loading branch information
austingmhuang committed Sep 17, 2024
1 parent 450a147 commit b0b8166
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 61 deletions.
6 changes: 5 additions & 1 deletion pennylane/qchem/hartree_fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,11 @@ def _nuclear_energy(*args):
for i, r1 in enumerate(coor):
for j, r2 in enumerate(coor[i + 1 :]):
e = e + (charges[i] * charges[i + j + 1] / qml.math.linalg.norm(r1 - r2))
return e[0]

if qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0]):
return e[0]

return e

return _nuclear_energy

Expand Down
11 changes: 5 additions & 6 deletions pennylane/qchem/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _overlap_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -534,7 +534,7 @@ def _moment_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -712,7 +712,7 @@ def _kinetic_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -930,13 +930,12 @@ def _attraction_integral(*args):

alpha, ca, ra = _generate_params(basis_a.params, args_a)
beta, cb, rb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -1091,7 +1090,7 @@ def _repulsion_integral(*args):
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down
57 changes: 49 additions & 8 deletions pennylane/qchem/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,14 @@ def overlap(*args):
for (i, a), (j, b) in it.combinations(enumerate(basis_functions), r=2):
args_ab = []
if args:
args_ab.extend([arg[i], arg[j]] for arg in args)
args_ab.extend(
(
qml.math.array([arg[i], arg[j]], like="jax")
if qml.math.get_interface(arg) == "jax"
else [arg[i], arg[j]]
)
for arg in args
)
integral = overlap_integral(a, b, normalize=False)(*args_ab)

o = qml.math.zeros((n, n))
Expand Down Expand Up @@ -148,7 +155,14 @@ def _moment_matrix(*args):
for (i, a), (j, b) in it.combinations_with_replacement(enumerate(basis_functions), r=2):
args_ab = []
if args:
args_ab.extend([arg[i], arg[j]] for arg in args)
args_ab.extend(
(
qml.math.array([arg[i], arg[j]], like="jax")
if qml.math.get_interface(arg) == "jax"
else [arg[i], arg[j]]
)
for arg in args
)
integral = moment_integral(a, b, order, idx, normalize=False)(*args_ab)

o = qml.math.zeros((n, n))
Expand Down Expand Up @@ -196,7 +210,14 @@ def kinetic(*args):
for (i, a), (j, b) in it.combinations_with_replacement(enumerate(basis_functions), r=2):
args_ab = []
if args:
args_ab.extend([arg[i], arg[j]] for arg in args)
args_ab.extend(
(
qml.math.array([arg[i], arg[j]], like="jax")
if qml.math.get_interface(arg) == "jax"
else [arg[i], arg[j]]
)
for arg in args
)
integral = kinetic_integral(a, b, normalize=False)(*args_ab)
o = qml.math.zeros((n, n))
o[i, j] = o[j, i] = 1.0
Expand Down Expand Up @@ -246,14 +267,27 @@ def attraction(*args):
integral = 0
if args:
args_ab = []

if getattr(r, "requires_grad", False) or (
qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0])
):
args_ab.extend([arg[i], arg[j]] for arg in args[1:])
else:
args_ab.extend([arg[i], arg[j]] for arg in args)
args_ab.extend(
(
qml.math.array([arg[i], arg[j]], like="jax")
if qml.math.get_interface(arg) == "jax"
else [arg[i], arg[j]]
)
for arg in args[1:]
)

else:
args_ab.extend(
(
qml.math.array([arg[i], arg[j]], like="jax")
if qml.math.get_interface(arg) == "jax"
else [arg[i], arg[j]]
)
for arg in args
)
for k, c in enumerate(r):
if getattr(c, "requires_grad", False) or (
qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0])
Expand Down Expand Up @@ -326,7 +360,14 @@ def repulsion(*args):
if qml.math.isnan(e_calc[(i, j, k, l)]):
args_abcd = []
if args:
args_abcd.extend([arg[i], arg[j], arg[k], arg[l]] for arg in args)
args_abcd.extend(
(
qml.math.array([arg[i], arg[j], arg[k], arg[l]], like="jax")
if qml.math.get_interface(arg) == "jax"
else [arg[i], arg[j], arg[k], arg[l]]
)
for arg in args
)
integral = repulsion_integral(a, b, c, d, normalize=False)(*args_abcd)

permutations = [
Expand Down
4 changes: 4 additions & 0 deletions pennylane/qchem/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def __init__(
)
for i in self.basis_data
]
if use_jax:
alpha = qml.math.array(alpha)

if coeff is None:
coeff = [
Expand All @@ -159,6 +161,8 @@ def __init__(
)
for i, c in enumerate(coeff)
]
if use_jax:
coeff = qml.math.array(coeff)

r = list(
itertools.chain(
Expand Down
2 changes: 0 additions & 2 deletions pennylane/qchem/observable_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def fermionic_observable(constant, one=None, two=None, cutoff=1.0e-12):
+ 0.5 * a⁺(3) a(3)
"""
coeffs = qml.math.array([])
if isinstance(constant, float):
constant = qml.math.array([constant])

if not qml.math.allclose(constant, 0.0):
coeffs = qml.math.concatenate((coeffs, constant))
Expand Down
17 changes: 1 addition & 16 deletions tests/qchem/test_hamiltonians.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,4 @@ def circuit(*args):

grad_jax = jax.grad(energy(mol), argnums=2)(*args)

alpha_1 = qml.math.array(
[[3.42425091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]],
like="jax",
) # alpha[0][0] -= 0.001

alpha_2 = qml.math.array(
[[3.42625091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]],
like="jax",
) # alpha[0][0] += 0.001

e_1 = energy(mol)(geometry, mol.coeff, alpha_1)
e_2 = energy(mol)(geometry, mol.coeff, alpha_2)

grad_finitediff = (e_2 - e_1) / 0.002

assert np.allclose(grad_jax[0][0], grad_finitediff, rtol=1e-02)
assert np.allclose(grad_jax[0][0], 0.02461335393055819, rtol=1e-02)
27 changes: 2 additions & 25 deletions tests/qchem/test_hartree_fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,29 +278,6 @@ def create_jax_like_array(values):

@pytest.mark.jax
class TestJax:
@pytest.mark.parametrize(
("symbols", "geometry", "g_ref"),
[
(
["H", "H"],
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
# HF gradient computed with pyscf using rnuc_grad_method().kernel()
[[0.0, 0.0, 0.3650435], [0.0, 0.0, -0.3650435]],
),
],
)
def test_hf_energy_gradient(self, symbols, geometry, g_ref):
r"""Test that the gradient of the Hartree-Fock energy wrt differentiable parameters is
correct."""
import jax

geometry = jax.numpy.array(geometry)
mol = qchem.Molecule(symbols, geometry)
args = [geometry, mol.coeff, mol.alpha]
g = jax.grad(qchem.hf_energy(mol), argnums=0)(*args)
g_ref = jax.numpy.array(g_ref)
assert np.allclose(g, g_ref)

@pytest.mark.parametrize(
("symbols", "geometry", "e_ref"),
[
Expand All @@ -322,7 +299,7 @@ def test_hf_energy_gradient(self, symbols, geometry, g_ref):
),
],
)
def test_nuclear_energy(self, symbols, geometry, e_ref):
def test_nuclear_energy_jax(self, symbols, geometry, e_ref):
r"""Test that nuclear_energy returns the correct energy when using jax."""
geometry = create_jax_like_array(geometry)
mol = qchem.Molecule(symbols, geometry)
Expand Down Expand Up @@ -382,6 +359,6 @@ def test_hf_energy_gradient_jax(self, symbols, geometry, g_ref):

mol = qchem.Molecule(symbols, geometry)
args = [geometry, mol.coeff, mol.alpha]
g = jax.grad(qchem.hf_energy(mol), argnums=0)(*args)
g = jax.grad(qchem.hf_energy(mol), argnums=[0])(*args)

assert np.allclose(g, g_ref)
5 changes: 2 additions & 3 deletions tests/qchem/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_gradient_overlap_matrix(
r"""Test that the overlap gradients are correct."""
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
args = [mol.alpha, mol.coeff]
g_alpha = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[0, 1])(*args)
g_alpha = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[0])(*args)
g_coeff = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[1])(*args)
assert np.allclose(g_alpha, g_alpha_ref)
assert np.allclose(g_coeff, g_coeff_ref)
Expand Down Expand Up @@ -1024,8 +1024,7 @@ def test_gradient_attraction_matrix_jax(self, symbols, geometry, alpha, coeff, g
alpha = qml.math.array(alpha, like="jax")
coeff = qml.math.array(coeff, like="jax")
mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff)
r_basis = mol.coordinates
args = [mol.coordinates, r_basis, mol.coeff, mol.alpha]
args = [mol.coordinates, mol.coordinates, mol.coeff, mol.alpha]

g_r, _, _, _ = jax.jacobian(
qchem.attraction_matrix(mol.basis_set, mol.nuclear_charges, mol.coordinates),
Expand Down

0 comments on commit b0b8166

Please sign in to comment.