diff --git a/tests/qchem/test_matrices.py b/tests/qchem/test_matrices.py index 4591bcde71e..94777f59a68 100644 --- a/tests/qchem/test_matrices.py +++ b/tests/qchem/test_matrices.py @@ -143,7 +143,7 @@ def test_gradient_overlap_matrix( r"""Test that the overlap gradients are correct.""" mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff) args = [mol.alpha, mol.coeff] - g_alpha = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[0])(*args) + g_alpha = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[0, 1])(*args) g_coeff = qml.jacobian(qchem.overlap_matrix(mol.basis_set), argnum=[1])(*args) assert np.allclose(g_alpha, g_alpha_ref) assert np.allclose(g_coeff, g_coeff_ref) @@ -693,6 +693,77 @@ def test_moment_matrix_jax(self): s = qchem.moment_matrix(mol.basis_set, e, idx)(*args) assert np.allclose(s, s_ref) + @pytest.mark.parametrize( + ("symbols", "geometry", "alpha", "coeff", "g_alpha_ref", "g_coeff_ref"), + [ + ( + ["H", "H"], + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], requires_grad=False), + np.array( + [[3.42525091, 0.62391373, 0.1688554], [3.42525091, 0.62391373, 0.1688554]], + requires_grad=True, + ), + np.array( + [[0.15432897, 0.53532814, 0.44463454], [0.15432897, 0.53532814, 0.44463454]], + requires_grad=True, + ), + # Jacobian matrix contains gradient of S11, S12, S21, S22 wrt arg_1, arg_2. + np.array( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [ + [-0.00043783, -0.09917143, -0.11600206], + [-0.00043783, -0.09917143, -0.11600206], + ], + ], + [ + [ + [-0.00043783, -0.09917143, -0.11600206], + [-0.00043783, -0.09917143, -0.11600206], + ], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + ] + ), + np.array( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [ + [-0.15627637, -0.02812029, 0.08809831], + [-0.15627637, -0.02812029, 0.08809831], + ], + ], + [ + [ + [-0.15627637, -0.02812029, 0.08809831], + [-0.15627637, -0.02812029, 0.08809831], + ], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + ] + ), + ) + ], + ) + def test_gradient_overlap_matrix_jax( + self, symbols, geometry, alpha, coeff, g_alpha_ref, g_coeff_ref + ): + r"""Test that the overlap gradients are correct with jax.""" + import jax + jax.config.update("jax_enable_x64", True) + + geometry = qml.math.array(geometry, like="jax") + alpha = qml.math.array(alpha, like="jax") + coeff = qml.math.array(coeff, like="jax") + mol = qchem.Molecule(symbols, geometry, alpha=alpha, coeff=coeff) + args = [mol.coordinates, mol.coeff, mol.alpha] + g_alpha, g_coeff = jax.jacobian(qchem.overlap_matrix(mol.basis_set), argnums=[2, 1])(*args) + + assert np.allclose(g_alpha, g_alpha_ref) + assert np.allclose(g_coeff, g_coeff_ref) + def test_kinetic_matrix_jax(self): r"""Test that kinetic_matrix returns the correct matrix when using jax.""" symbols, geometry, alpha = generate_symbols_geometry_alpha()