Skip to content

Commit

Permalink
PAINFUL ORDERING
Browse files Browse the repository at this point in the history
  • Loading branch information
austingmhuang committed Sep 13, 2024
1 parent 110b641 commit 8ab08eb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
6 changes: 3 additions & 3 deletions pennylane/qchem/hartree_fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ def _scf(*args):
if qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0]):
args_r = [[args[0][i]] * mol.n_basis[i] for i in range(len(mol.n_basis))]
args_ = [*args] + [qml.math.vstack(list(itertools.chain(*args_r)))]
rep_tensor = repulsion_tensor(basis_functions)(args_[2], args_[1], args_[3])
s = overlap_matrix(basis_functions)(args_[2], args_[1], args_[3])
rep_tensor = repulsion_tensor(basis_functions)(args_[3], args_[1], args_[2])
s = overlap_matrix(basis_functions)(args_[3], args_[1], args_[2])
h_core = core_matrix(basis_functions, charges, r)(
args_[0], args_[2], args_[1], args_[3]
args_[0], args_[3], args_[1], args_[2]
)
else:
rep_tensor = repulsion_tensor(basis_functions)(*args)
Expand Down
49 changes: 36 additions & 13 deletions pennylane/qchem/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,22 @@ def _overlap_integral(*args):
Returns:
array[float]: the overlap integral between two contracted Gaussian orbitals
"""
args_a = [arg[0] for arg in args]
args_b = [arg[1] for arg in args]
alpha, ca, ra = _generate_params(basis_a.params, args_a)
args_a = [arg[0] for arg in args] # in autograd: coeff, alpha, coord
args_b = [arg[1] for arg in args] # in jax: alpha, coeff, coord

alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -517,15 +522,19 @@ def _moment_integral(*args):
la = basis_a.l
lb = basis_b.l

alpha, ca, ra = _generate_params(basis_a.params, args_a)
alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -695,15 +704,20 @@ def _kinetic_integral(*args):
"""
args_a = [arg[0] for arg in args]
args_b = [arg[1] for arg in args]
alpha, ca, ra = _generate_params(basis_a.params, args_a)

alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -909,7 +923,7 @@ def _attraction_integral(*args):
array[float]: the electron-nuclear attraction integral
"""
if getattr(r, "requires_grad", False) or (
qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0])
qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0][0])
):
coor = args[0]
args_a = [arg[0] for arg in args[1:]]
Expand All @@ -919,15 +933,18 @@ def _attraction_integral(*args):
args_a = [arg[0] for arg in args]
args_b = [arg[1] for arg in args]

alpha, ca, ra = _generate_params(basis_a.params, args_a)
alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)
if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -1072,17 +1089,23 @@ def _repulsion_integral(*args):
args_c = [arg[2] for arg in args]
args_d = [arg[3] for arg in args]

alpha, ca, ra = _generate_params(basis_a.params, args_a)
alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
beta, cb, rb = _generate_params(basis_b.params, args_b)
gamma, cc, rc = _generate_params(basis_c.params, args_c)
delta, cd, rd = _generate_params(basis_d.params, args_d)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)
gamma, rc, cc = _generate_params(basis_c.params, args_c)
delta, rd, cd = _generate_params(basis_d.params, args_d)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1])
and qml.math.requires_grad(args[1][0])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down

0 comments on commit 8ab08eb

Please sign in to comment.