Skip to content

Commit

Permalink
revert cause it's not actually fully correct
Browse files Browse the repository at this point in the history
  • Loading branch information
austingmhuang committed Sep 13, 2024
1 parent 8ab08eb commit eaca392
Showing 1 changed file with 9 additions and 31 deletions.
40 changes: 9 additions & 31 deletions pennylane/qchem/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,6 @@ def _overlap_integral(*args):
alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
Expand Down Expand Up @@ -522,19 +518,15 @@ def _moment_integral(*args):
la = basis_a.l
lb = basis_b.l

alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
alpha, ca, ra = _generate_params(basis_a.params, args_a)
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -704,20 +696,15 @@ def _kinetic_integral(*args):
"""
args_a = [arg[0] for arg in args]
args_b = [arg[1] for arg in args]

alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
alpha, ca, ra = _generate_params(basis_a.params, args_a)
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -923,7 +910,7 @@ def _attraction_integral(*args):
array[float]: the electron-nuclear attraction integral
"""
if getattr(r, "requires_grad", False) or (
qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0][0])
qml.math.get_interface(r) == "jax" and qml.math.requires_grad(args[0])
):
coor = args[0]
args_a = [arg[0] for arg in args[1:]]
Expand All @@ -933,18 +920,15 @@ def _attraction_integral(*args):
args_a = [arg[0] for arg in args]
args_b = [arg[1] for arg in args]

alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
alpha, ca, ra = _generate_params(basis_a.params, args_a)
beta, cb, rb = _generate_params(basis_b.params, args_b)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)
if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down Expand Up @@ -1089,23 +1073,17 @@ def _repulsion_integral(*args):
args_c = [arg[2] for arg in args]
args_d = [arg[3] for arg in args]

alpha, ca, ra = _generate_params(basis_a.params, args_a) # cca
alpha, ca, ra = _generate_params(basis_a.params, args_a)
beta, cb, rb = _generate_params(basis_b.params, args_b)
gamma, cc, rc = _generate_params(basis_c.params, args_c)
delta, cd, rd = _generate_params(basis_d.params, args_d)

if qml.math.get_interface(basis_a.params[1]) == "jax":
alpha, ra, ca = _generate_params(basis_a.params, args_a)
beta, rb, cb = _generate_params(basis_b.params, args_b)
gamma, rc, cc = _generate_params(basis_c.params, args_c)
delta, rd, cd = _generate_params(basis_d.params, args_d)

if (
getattr(basis_a.params[1], "requires_grad", False)
or normalize
or (
qml.math.get_interface(basis_a.params[1]) == "jax"
and qml.math.requires_grad(args[1][0])
and qml.math.requires_grad(args[1])
)
):
ca = ca * primitive_norm(basis_a.l, alpha)
Expand Down

0 comments on commit eaca392

Please sign in to comment.