diff --git a/pennylane/transforms/optimization/optimization_utils.py b/pennylane/transforms/optimization/optimization_utils.py index 9069f2b24c5..137a1009115 100644 --- a/pennylane/transforms/optimization/optimization_utils.py +++ b/pennylane/transforms/optimization/optimization_utils.py @@ -13,21 +13,6 @@ # limitations under the License. """Utility functions for circuit optimization.""" # pylint: disable=too-many-return-statements,import-outside-toplevel -from pennylane.math import ( - allclose, - arccos, - arctan2, - asarray, - cast_like, - cos, - is_abstract, - moveaxis, - requires_grad, - sin, - sqrt, - stack, - zeros_like, -) from pennylane.ops.identity import GlobalPhase from pennylane.wires import Wires @@ -60,21 +45,25 @@ def _try_no_fuse(angles1, angles2): if some angles in the input angles vanish.""" dummy_sum = angles1 + angles2 # moveaxis required for batched inputs - phi1, theta1, omega1 = moveaxis(cast_like(asarray(angles1), dummy_sum), -1, 0) - phi2, theta2, omega2 = moveaxis(cast_like(asarray(angles2), dummy_sum), -1, 0) - - if allclose(omega1 + phi2, 0.0): - return stack([phi1, theta1 + theta2, omega2]) - zero = zeros_like(phi1) + zeros_like(phi2) - if allclose(theta1, 0.0): + phi1, theta1, omega1 = qml.math.moveaxis( + qml.math.cast_like(qml.math.asarray(angles1), dummy_sum), -1, 0 + ) + phi2, theta2, omega2 = qml.math.moveaxis( + qml.math.cast_like(qml.math.asarray(angles2), dummy_sum), -1, 0 + ) + + if qml.math.allclose(omega1 + phi2, 0.0): + return qml.math.stack([phi1, theta1 + theta2, omega2]) + zero = qml.math.zeros_like(phi1) + qml.math.zeros_like(phi2) + if qml.math.allclose(theta1, 0.0): # No Y rotation in first Rot - if allclose(theta2, 0.0): + if qml.math.allclose(theta2, 0.0): # Z rotations only - return stack([phi1 + omega1 + phi2 + omega2, zero, zero]) - return stack([phi1 + omega1 + phi2, theta2, omega2]) - if allclose(theta2, 0.0): + return qml.math.stack([phi1 + omega1 + phi2 + omega2, zero, zero]) + return qml.math.stack([phi1 + omega1 + phi2, theta2, omega2]) + if qml.math.allclose(theta2, 0.0): # No Y rotation in second Rot - return stack([phi1, theta1, omega1 + phi2 + omega2]) + return qml.math.stack([phi1, theta1, omega1 + phi2 + omega2]) return None @@ -118,35 +107,38 @@ def fuse_rot_angles(angles1, angles2): """ if not ( - is_abstract(angles1) - or is_abstract(angles2) - or requires_grad(angles1) - or requires_grad(angles2) + qml.math.is_abstract(angles1) + or qml.math.is_abstract(angles2) + or qml.math.requires_grad(angles1) + or qml.math.requires_grad(angles2) ): fused_angles = _try_no_fuse(angles1, angles2) if fused_angles is not None: return fused_angles # moveaxis required for batched inputs - phi1, theta1, omega1 = moveaxis(asarray(angles1), -1, 0) - phi2, theta2, omega2 = moveaxis(asarray(angles2), -1, 0) - c1, c2, s1, s2 = cos(theta1 / 2), cos(theta2 / 2), sin(theta1 / 2), sin(theta2 / 2) + phi1, theta1, omega1 = qml.math.moveaxis(qml.math.asarray(angles1), -1, 0) + phi2, theta2, omega2 = qml.math.moveaxis(qml.math.asarray(angles2), -1, 0) + c1, c2 = qml.math.cos(theta1 / 2), qml.math.cos(theta2 / 2) + s1, s2 = qml.math.sin(theta1 / 2), qml.math.sin(theta2 / 2) - mag = sqrt(c1**2 * c2**2 + s1**2 * s2**2 - 2 * c1 * c2 * s1 * s2 * cos(omega1 + phi2)) - theta_f = 2 * arccos(mag) + mag = qml.math.sqrt( + c1**2 * c2**2 + s1**2 * s2**2 - 2 * c1 * c2 * s1 * s2 * qml.math.cos(omega1 + phi2) + ) + theta_f = 2 * qml.math.arccos(mag) alpha1, beta1 = (phi1 + omega1) / 2, (phi1 - omega1) / 2 alpha2, beta2 = (phi2 + omega2) / 2, (phi2 - omega2) / 2 - alpha_arg1 = -c1 * c2 * sin(alpha1 + alpha2) - s1 * s2 * sin(beta2 - beta1) - alpha_arg2 = c1 * c2 * cos(alpha1 + alpha2) - s1 * s2 * cos(beta2 - beta1) - alpha_f = -1 * arctan2(alpha_arg1, alpha_arg2) + alpha_arg1 = -c1 * c2 * qml.math.sin(alpha1 + alpha2) - s1 * s2 * qml.math.sin(beta2 - beta1) + alpha_arg2 = c1 * c2 * qml.math.cos(alpha1 + alpha2) - s1 * s2 * qml.math.cos(beta2 - beta1) + alpha_f = -1 * qml.math.arctan2(alpha_arg1, alpha_arg2) - beta_arg1 = -c1 * s2 * sin(alpha1 + beta2) + s1 * c2 * sin(alpha2 - beta1) - beta_arg2 = c1 * s2 * cos(alpha1 + beta2) + s1 * c2 * cos(alpha2 - beta1) - beta_f = -1 * arctan2(beta_arg1, beta_arg2) + beta_arg1 = -c1 * s2 * qml.math.sin(alpha1 + beta2) + s1 * c2 * qml.math.sin(alpha2 - beta1) + beta_arg2 = c1 * s2 * qml.math.cos(alpha1 + beta2) + s1 * c2 * qml.math.cos(alpha2 - beta1) + beta_f = -1 * qml.math.arctan2(beta_arg1, beta_arg2) - return stack([alpha_f + beta_f, theta_f, alpha_f - beta_f], axis=-1) + return qml.math.stack([alpha_f + beta_f, theta_f, alpha_f - beta_f], axis=-1) def _fuse_global_phases(operations):