Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Conditional a SymbolicOp #5772

Merged
merged 12 commits into from
Jun 5, 2024
Merged
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@
`par_info`, `obs_sharing_wires`, and `obs_sharing_wires_id` are now public attributes.
[(#5696)](https://github.com/PennyLaneAI/pennylane/pull/5696)

* `Conditional` now inherits from `SymbolicOp`, thus it inherits several useful common functionalities. Other properties such as adjoint and diagonalizing gates have been added using the `base` properties.
EmilianoG-byte marked this conversation as resolved.
Show resolved Hide resolved
[(##5772)](https://github.com/PennyLaneAI/pennylane/pull/5772)

* The `qml.qchem.Molecule` object is now the central object used by all qchem functions.
[(#5571)](https://github.com/PennyLaneAI/pennylane/pull/5571)

Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def apply_conditional(
return cond(
op.meas_val.concretize(mid_measurements),
lambda x: apply_operation(
op.then_op,
op.base,
x,
is_state_batched=is_state_batched,
debugger=debugger,
Expand All @@ -261,7 +261,7 @@ def apply_conditional(
)
if op.meas_val.concretize(mid_measurements):
return apply_operation(
op.then_op,
op.base,
state,
is_state_batched=is_state_batched,
debugger=debugger,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/drawer/drawable_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_op_occupied_wires(op, wire_map, bit_map):
return {mapped_wire}

if isinstance(op, Conditional):
mapped_wires = [wire_map[wire] for wire in op.then_op.wires]
mapped_wires = [wire_map[wire] for wire in op.base.wires]
min_wire = min(mapped_wires)
max_wire = max(wire_map.values())
return set(range(min_wire, max_wire + 1))
Expand Down
2 changes: 1 addition & 1 deletion pennylane/drawer/tape_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _(op: qml.ops.op_math.Conditional, drawer, layer, config) -> None:
drawer.box_gate(
layer,
list(op.wires),
op.then_op.label(decimals=config.decimals),
op.base.label(decimals=config.decimals),
box_options={"zorder": 4},
text_options={"zorder": 5},
)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/drawer/tape_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _add_op(op, layer_str, config):
"""Updates ``layer_str`` with ``op`` operation."""
if isinstance(op, qml.ops.Conditional): # pylint: disable=no-member
layer_str = _add_cond_grouping_symbols(op, layer_str, config)
return _add_op(op.then_op, layer_str, config)
return _add_op(op.base, layer_str, config)

if isinstance(op, MidMeasureMP):
return _add_mid_measure_op(op, layer_str, config)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/functions/bind_new_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def bind_new_parameters_tensor(op: Tensor, params: Sequence[TensorLike]):

@bind_new_parameters.register
def bind_new_parameters_conditional(op: qml.ops.Conditional, params: Sequence[TensorLike]):
then_op = bind_new_parameters(op.then_op, params)
then_op = bind_new_parameters(op.base, params)
mv = copy.deepcopy(op.meas_val)

return qml.ops.Conditional(mv, then_op)
20 changes: 18 additions & 2 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pennylane.ops import (
Adjoint,
CompositeOp,
Conditional,
Controlled,
Exp,
Hamiltonian,
Expand Down Expand Up @@ -440,6 +441,22 @@ def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
return qml.equal(op1.base, op2.base, **kwargs)


@_equal.register
def _equal_conditional(op1: Conditional, op2: Conditional, **kwargs):
"""Determine whether two Conditional objects are equal"""
# first line of top-level equal function already confirms both are Conditionaly - only need to compare bases and meas_val
return qml.equal(op1.base, op2.base, **kwargs) and qml.equal(
op1.meas_val, op2.meas_val, **kwargs
)


@_equal.register
# pylint: disable=unused-argument
def _equal_measurement_value(op1: MeasurementValue, op2: MeasurementValue, **kwargs):
"""Determine whether two MeasurementValue objects are equal"""
return op1.measurements == op2.measurements


@_equal.register
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
Expand Down Expand Up @@ -563,9 +580,8 @@ def _equal_measurements(
)

if op1.mv is not None and op2.mv is not None:
# qml.equal doesn't check if the MeasurementValues have the same processing functions
if isinstance(op1.mv, MeasurementValue) and isinstance(op2.mv, MeasurementValue):
return op1.mv.measurements == op2.mv.measurements
return qml.equal(op1.mv, op2.mv)

if isinstance(op1.mv, Iterable) and isinstance(op2.mv, Iterable):
if len(op1.mv) == len(op2.mv):
Expand Down
43 changes: 35 additions & 8 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
from pennylane import QueuingManager
from pennylane.compiler import compiler
from pennylane.operation import AnyWires, Operation, Operator
from pennylane.ops.op_math.symbolicop import SymbolicOp
from pennylane.tape import make_qscript


class ConditionalTransformError(ValueError):
"""Error for using qml.cond incorrectly"""


class Conditional(Operation):
class Conditional(SymbolicOp):
"""A Conditional Operation.

Unless you are a Pennylane plugin developer, **you should NOT directly use this class**,
Expand All @@ -50,26 +51,52 @@ class Conditional(Operation):
num_wires = AnyWires

def __init__(self, expr, then_op: Type[Operation], id=None):
self.meas_val = expr
self.then_op = then_op
super().__init__(*then_op.data, wires=then_op.wires, id=id)
self.hyperparameters["meas_val"] = expr
self._name = f"Conditional({then_op.name})"
super().__init__(then_op, id=id)

def label(self, decimals=None, base_label=None, cache=None):
return self.then_op.label(decimals=decimals, base_label=base_label, cache=cache)
return self.base.label(decimals=decimals, base_label=base_label, cache=cache)

@property
def meas_val(self):
"the measurement outcome value to consider from `expr` argument"
return self.hyperparameters["meas_val"]

@property
def num_params(self):
return self.then_op.num_params
return self.base.num_params

@property
def ndim_params(self):
return self.then_op.ndim_params
return self.base.ndim_params

def map_wires(self, wire_map):
meas_val = self.meas_val.map_wires(wire_map)
then_op = self.then_op.map_wires(wire_map)
then_op = self.base.map_wires(wire_map)
return Conditional(meas_val, then_op=then_op)

def matrix(self, wire_order=None):
return self.base.matrix(wire_order=wire_order)

# pylint: disable=arguments-renamed, invalid-overridden-method
@property
def has_diagonalizing_gates(self):
return self.base.has_diagonalizing_gates

def diagonalizing_gates(self):
return self.base.diagonalizing_gates()

def eigvals(self):
return self.base.eigvals()

@property
def has_adjoint(self):
return self.base.has_adjoint

def adjoint(self):
return Conditional(self.meas_val, self.base.adjoint())


def cond(condition, true_fn, false_fn=None, elifs=()):
"""Quantum-compatible if-else conditionals --- condition quantum operations
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/defer_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,11 @@ def _add_control_gate(op, control_wires, reduce_postselected):
if value:
# Empty sampling branches can occur when using _postselected_items
if branch == ():
new_ops.append(op.then_op)
new_ops.append(op.base)
continue
qscript = qml.tape.make_qscript(
ctrl(
lambda: qml.apply(op.then_op), # pylint: disable=cell-var-from-loop
lambda: qml.apply(op.base), # pylint: disable=cell-var-from-loop
control=Wires(control),
control_values=branch,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/functions/test_bind_new_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_conditional_ops(op, new_params, expected_op):
new_op = bind_new_parameters(cond_op, new_params)

assert isinstance(new_op, qml.ops.Conditional)
assert new_op.then_op == expected_op
assert new_op.base == expected_op
assert new_op.meas_val.measurements == [mp0]


Expand Down
97 changes: 97 additions & 0 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pennylane import numpy as npp
from pennylane.measurements import ExpectationMP
from pennylane.measurements.probs import ProbabilityMP
from pennylane.ops import Conditional
from pennylane.ops.functions.equal import _equal, assert_equal
from pennylane.ops.op_math import Controlled, SymbolicOp
from pennylane.templates.subroutines import ControlledSequence
Expand Down Expand Up @@ -1286,6 +1287,35 @@ def test_mid_measure(self):
qml.measurements.MidMeasureMP(wires=qml.wires.Wires([0, 1]), reset=True, id="test_id"),
)

def test_equal_measurement_value(self):
"""Test that MeasurementValue's are equal when their measurements are the same."""
mv1 = qml.measure(0)
mv2 = qml.measure(0)
# qml.equal of MidMeasureMP checks the id
mv2.measurements[0].id = mv1.measurements[0].id

assert qml.equal(mv1, mv1)
assert qml.equal(mv1, mv2)

def test_different_measurement_value(self):
"""Test that MeasurementValue's are different when their measurements are not the same."""
mv1 = qml.measure(0)
mv2 = qml.measure(1)
assert not qml.equal(mv1, mv2)

def test_composed_measurement_value(self):
"""Test that composition of MeasurementValue's are checked correctly."""
mv1 = qml.measure(0)
mv2 = qml.measure(1)
mv3 = qml.measure(0)
# qml.equal of MidMeasureMP checks the id
mv3.measurements[0].id = mv1.measurements[0].id

assert qml.equal(mv1 * mv2, mv2 * mv1)
assert qml.equal(mv1 + mv2, mv3 + mv2)
# NOTE: we are deliberatily just checking for measurements and not for processing_fn, such that two MeasurementValue objects composed from the same operators will be qml.equal
assert qml.equal(3 * mv1 + 1, 4 * mv3 + 2)

@pytest.mark.parametrize("mp_fn", [qml.probs, qml.sample, qml.counts])
def test_mv_list_as_op(self, mp_fn):
"""Test that MeasurementProcesses that measure a list of MeasurementValues check for equality
Expand Down Expand Up @@ -1648,6 +1678,73 @@ def test_adjoint_base_op_comparison_with_trainability(self):
assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

@pytest.mark.parametrize(("wire1", "wire2", "res"), WIRES)
def test_conditional_base_operator_wire_comparison(self, wire1, wire2, res):
"""Test that equal compares operator wires for Conditional operators"""
m = qml.measure(0)
base1 = qml.PauliX(wire1)
base2 = qml.PauliX(wire2)
op1 = Conditional(m, base1)
op2 = Conditional(m, base2)
assert qml.equal(op1, op2) == res

@pytest.mark.parametrize(("wire1", "wire2", "res"), WIRES)
def test_conditional_measurement_value_wire_comparison(self, wire1, wire2, res):
"""Test that equal compares operator wires for Conditional operators"""
m1 = qml.measure(wire1)
m2 = qml.measure(wire2)
if wire1 == wire2:
# qml.equal checks id for MidMeasureMP, but here we only care about them acting on the same wire
m2.measurements[0].id = m1.measurements[0].id
base = qml.PauliX(wire2)
op1 = Conditional(m1, base)
op2 = Conditional(m2, base)
assert qml.equal(op1, op2) == res

@pytest.mark.parametrize(("base1", "base2", "res"), BASES)
def test_conditional_base_operator_comparison(self, base1, base2, res):
"""Test that equal compares base operators for Conditional operators"""
m = qml.measure(0)
op1 = Conditional(m, base1)
op2 = Conditional(m, base2)
assert qml.equal(op1, op2) == res

def test_conditional_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance of the Conditional class."""
m = qml.measure(0)
base1 = qml.RX(1.2, wires=0)
base2 = qml.RX(1.2 + 1e-4, wires=0)
op1 = Conditional(m, base1)
op2 = Conditional(m, base2)

assert qml.equal(op1, op2, atol=1e-3, rtol=0)
assert not qml.equal(op1, op2, atol=1e-5, rtol=0)
assert qml.equal(op1, op2, atol=0, rtol=1e-3)
assert not qml.equal(op1, op2, atol=0, rtol=1e-5)

def test_conditional_base_op_comparison_with_interface(self):
"""Test that equal compares the parameters within a provided interface of the base operator of Conditional class."""
m = qml.measure(0)
base1 = qml.RX(1.2, wires=0)
base2 = qml.RX(npp.array(1.2), wires=0)
op1 = Conditional(m, base1)
op2 = Conditional(m, base2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=True, check_trainability=False)

def test_conditional_base_op_comparison_with_trainability(self):
"""Test that equal compares the parameters within a provided trainability of the base operator of Conditional class."""

m = qml.measure(0)
base1 = qml.RX(npp.array(1.2, requires_grad=False), wires=0)
base2 = qml.RX(npp.array(1.2, requires_grad=True), wires=0)
op1 = Conditional(m, base1)
op2 = Conditional(m, base2)

assert qml.equal(op1, op2, check_interface=False, check_trainability=False)
assert not qml.equal(op1, op2, check_interface=False, check_trainability=True)

@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_pow_comparison(self, bases_bases_match, params_params_match):
Expand Down
Loading
Loading