Skip to content

Commit

Permalink
Adding more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Jul 24, 2024
1 parent 6a56972 commit 9d50198
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 225 deletions.
2 changes: 1 addition & 1 deletion pennylane/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def get_interface(*values):
# contains autograd and another interface
warnings.warn(
f"Contains tensors of types {non_numpy_scipy_interfaces}; dispatch will prioritize "
"TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.",
"TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.",
UserWarning,
)

Expand Down
15 changes: 11 additions & 4 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def cond(condition, true_fn, false_fn=None, elifs=()):
will be captured by Catalyst, the just-in-time (JIT) compiler, with the executed
branch determined at runtime. For more details, please see :func:`catalyst.cond`.
When used with `qml.capture.enabled()` equal to ``True``, this function allows
for general if-elif-else constructs. As with the JIT mode, all branches will be
captured, with the executed branch determined at runtime. Each branch can receive parameters.
However, the function cannot branch on mid-circuit measurements.
If a branch returns one or more variables, every other branch must return the same abstract values.
If a branch returns one or more operators, these will be appended to the QueuingManager.
.. note::
With the Python interpreter, support for :func:`~.cond`
Expand Down Expand Up @@ -511,15 +518,15 @@ def validate_abstract_values(
)

outvals_true = jaxpr_true.out_avals
outvals_false = jaxpr_false.out_avals if jaxpr_false is not None else []

if jaxpr_false is not None:
outvals_false = jaxpr_false.out_avals
validate_abstract_values(outvals_false, outvals_true, "false")

for idx, jaxpr_elif in enumerate(jaxpr_elifs):
outvals_elif = jaxpr_elif.out_avals
validate_abstract_values(outvals_elif, outvals_true, "elif", idx)

if outvals_false:
validate_abstract_values(outvals_false, outvals_true, "false")

# We return the abstract values of the true branch since the abstract values
# of the false and elif branches (if they exist) should be the same
return outvals_true
Expand Down
Loading

0 comments on commit 9d50198

Please sign in to comment.