From b797b1d3f3af6b8f134c12c8704b5ceb6ab55fd8 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Mon, 3 Jun 2024 10:02:24 -0400 Subject: [PATCH 1/3] Capture a qnode into jaxpr (#5708) Co-authored-by: dwierichs Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com> --- doc/releases/changelog-dev.md | 3 +- pennylane/capture/__init__.py | 2 + pennylane/capture/capture_qnode.py | 167 ++++++++++++++++ pennylane/workflow/qnode.py | 9 +- tests/capture/test_capture_qnode.py | 298 ++++++++++++++++++++++++++++ 5 files changed, 476 insertions(+), 3 deletions(-) create mode 100644 pennylane/capture/capture_qnode.py create mode 100644 tests/capture/test_capture_qnode.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3d90194d325..9826b0d28c9 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -108,9 +108,10 @@ `m = measure(0); qml.sample(m)`. [(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673) -* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR. +* PennyLane operators, measurements, and QNodes can now automatically be captured as instructions in JAXPR. [(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564) [(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511) + [(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708) [(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523) * The `decompose` transform has an `error` kwarg to specify the type of error that should be raised, diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 33edf07dbe6..3cbb39a3dfa 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -33,6 +33,7 @@ ~create_measurement_obs_primitive ~create_measurement_wires_primitive ~create_measurement_mcm_primitive + ~qnode_call To activate and deactivate the new PennyLane program capturing mechanism, use the switches ``qml.capture.enable`` and ``qml.capture.disable``. @@ -132,6 +133,7 @@ def _(*args, **kwargs): create_measurement_wires_primitive, create_measurement_mcm_primitive, ) +from .capture_qnode import qnode_call def __getattr__(key): diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py new file mode 100644 index 00000000000..bbcd731e934 --- /dev/null +++ b/pennylane/capture/capture_qnode.py @@ -0,0 +1,167 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule defines a capture compatible call to QNodes. +""" + +from functools import lru_cache, partial + +import pennylane as qml + +has_jax = True +try: + import jax +except ImportError: + has_jax = False + + +def _get_shapes_for(*measurements, shots=None, num_device_wires=0): + if jax.config.jax_enable_x64: + dtype_map = { + float: jax.numpy.float64, + int: jax.numpy.int64, + complex: jax.numpy.complex128, + } + else: + dtype_map = { + float: jax.numpy.float32, + int: jax.numpy.int32, + complex: jax.numpy.complex64, + } + + shapes = [] + if not shots: + shots = [None] + + for s in shots: + for m in measurements: + shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires) + shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype))) + return shapes + + +@lru_cache() +def _get_qnode_prim(): + if not has_jax: + return None + qnode_prim = jax.core.Primitive("qnode") + qnode_prim.multiple_results = True + + @qnode_prim.def_impl + def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def qfunc(*inner_args): + return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args) + + qnode = qml.QNode(qfunc, device, **qnode_kwargs) + return qnode._impl_call(*args, shots=shots) # pylint: disable=protected-access + + # pylint: disable=unused-argument + @qnode_prim.def_abstract_eval + def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + mps = qfunc_jaxpr.out_avals + return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) + + return qnode_prim + + +# pylint: disable=protected-access +def _get_device_shots(device) -> "qml.measurements.Shots": + if isinstance(device, qml.devices.LegacyDevice): + if device._shot_vector: + return qml.measurements.Shots(device._raw_shot_sequence) + return qml.measurements.Shots(device.shots) + return device.shots + + +def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result": + """A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``. + + Args: + qnode (QNode): a QNode + args: the arguments the QNode is called with + + Keyword Args: + Any keyword arguments accepted by the quantum function + + Returns: + qml.typing.Result: the result of a qnode execution + + **Example:** + + .. code-block:: python + + @qml.qnode(qml.device('lightning.qubit', wires=1)) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.Z(0)), qml.probs() + + def f(x): + expval_z, probs = circuit(np.pi * x, shots=50) + return 2*expval_z + probs + + jaxpr = jax.make_jaxpr(f)(0.1) + print("jaxpr:") + print(jaxpr) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.7) + print() + print("result:") + print(res) + + + .. code-block:: none + + jaxpr: + { lambda ; a:f32[]. let + b:f32[] = mul 3.141592653589793 a + c:f32[] d:f32[2] = qnode[ + device= + qfunc_jaxpr={ lambda ; e:f32[]. let + _:AbstractOperator() = RX[n_wires=1] e 0 + f:AbstractOperator() = PauliZ[n_wires=1] 0 + g:AbstractMeasurement(n_wires=None) = expval_obs f + h:AbstractMeasurement(n_wires=0) = probs_wires + in (g, h) } + qnode_kwargs={'diff_method': 'best', 'grad_on_execution': 'best', 'cache': False, 'cachesize': 10000, 'max_diff': 1, 'max_expansion': 10, 'device_vjp': False} + shots=Shots(total=50) + ] b + i:f32[] = mul 2.0 c + j:f32[2] = add i d + in (j,) } + + result: + [Array([-1.3 , -0.74], dtype=float32)] + + + """ + if "shots" in kwargs: + shots = qml.measurements.Shots(kwargs.pop("shots")) + else: + shots = _get_device_shots(qnode.device) + if shots.has_partitioned_shots: + # Questions over the pytrees and the nested result object shape + raise NotImplementedError("shot vectors are not yet supported with plxpr capture.") + + if not qnode.device.wires: + raise NotImplementedError("devices must specify wires for integration with plxpr capture.") + + qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func + + qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args) + qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs} + qnode_prim = _get_qnode_prim() + + return qnode_prim.bind( + *args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr + ) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 57d6e89c28b..947a636b08b 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 Xanadu Quantum Technologies Inc. +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1081,7 +1081,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml res, self._qfunc_output, self._tape.shots.has_partitioned_shots ) - def __call__(self, *args, **kwargs) -> qml.typing.Result: + def _impl_call(self, *args, **kwargs) -> qml.typing.Result: old_interface = self.interface if old_interface == "auto": @@ -1113,6 +1113,11 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: return res + def __call__(self, *args, **kwargs) -> qml.typing.Result: + if qml.capture.enabled(): + return qml.capture.qnode_call(self, *args, **kwargs) + return self._impl_call(*args, **kwargs) + qnode = lambda device, **kwargs: functools.partial(QNode, device=device, **kwargs) qnode.__doc__ = QNode.__doc__ diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py new file mode 100644 index 00000000000..1fe57c65e23 --- /dev/null +++ b/tests/capture/test_capture_qnode.py @@ -0,0 +1,298 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for capturing a qnode into jaxpr. +""" +from functools import partial + +# pylint: disable=protected-access +import pytest + +import pennylane as qml +from pennylane.capture.capture_qnode import _get_qnode_prim + +qnode_prim = _get_qnode_prim() + +pytestmark = pytest.mark.jax + +jax = pytest.importorskip("jax") + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + qml.capture.enable() + yield + qml.capture.disable() + + +@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy")) +def test_error_if_shot_vector(dev_name): + """Test that a NotImplementedError is raised if a shot vector is provided.""" + + dev = qml.device(dev_name, wires=1, shots=(50, 50)) + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"): + jax.make_jaxpr(circuit)() + + with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"): + circuit() + + jax.make_jaxpr(partial(circuit, shots=50))() # should run fine + res = circuit(shots=50) + assert qml.math.allclose(res, jax.numpy.zeros((50,))) + + +@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy")) +def test_error_if_overridden_shot_vector(dev_name): + """Test that a NotImplementedError is raised if a shot vector is provided on call.""" + + dev = qml.device(dev_name, wires=1) + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"): + jax.make_jaxpr(partial(circuit, shots=(1, 1, 1)))() + + +def test_error_if_no_device_wires(): + """Test that a NotImplementedError is raised if the device does not provide wires.""" + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + with pytest.raises(NotImplementedError, match="devices must specify wires"): + jax.make_jaxpr(circuit)() + + with pytest.raises(NotImplementedError, match="devices must specify wires"): + circuit() + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_simple_qnode(x64_mode): + """Test capturing a qnode for a simple use.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + dev = qml.device("default.qubit", wires=4) + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.Z(0)) + + res = circuit(0.5) + assert qml.math.allclose(res, jax.numpy.cos(0.5)) + + jaxpr = jax.make_jaxpr(circuit)(0.5) + + assert len(jaxpr.eqns) == 1 + eqn0 = jaxpr.eqns[0] + + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + + assert eqn0.primitive == qnode_prim + assert eqn0.invars[0].aval == jaxpr.in_avals[0] + assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype) + + assert eqn0.params["device"] == dev + assert eqn0.params["shots"] == qml.measurements.Shots(None) + expected_kwargs = {"diff_method": "best"} + expected_kwargs.update(circuit.execute_kwargs) + assert eqn0.params["qnode_kwargs"] == expected_kwargs + + qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] + assert len(qfunc_jaxpr.eqns) == 3 + assert qfunc_jaxpr.eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + + assert len(eqn0.outvars) == 1 + assert eqn0.outvars[0].aval == jax.core.ShapedArray((), fdtype) + + output = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5) + assert qml.math.allclose(output[0], jax.numpy.cos(0.5)) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy")) +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_overriding_shots(dev_name, x64_mode): + """Test that the number of shots can be overridden on call.""" + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + dev = qml.device(dev_name, wires=1) + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + jaxpr = jax.make_jaxpr(partial(circuit, shots=50))() + assert len(jaxpr.eqns) == 1 + eqn0 = jaxpr.eqns[0] + + assert eqn0.primitive == qnode_prim + assert eqn0.params["device"] == dev + assert eqn0.params["shots"] == qml.measurements.Shots(50) + assert ( + eqn0.params["qfunc_jaxpr"].eqns[0].primitive == qml.measurements.SampleMP._wires_primitive + ) + + assert eqn0.outvars[0].aval == jax.core.ShapedArray( + (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res, jax.numpy.zeros((50,))) + + jax.config.update("jax_enable_x64", initial_mode) + + +def test_providing_keyword_argument(): + """Test that keyword arguments can be provided to the qnode.""" + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(*, n_iterations=0): + for _ in range(n_iterations): + qml.X(0) + return qml.probs() + + jaxpr = jax.make_jaxpr(partial(circuit, n_iterations=3))() + + assert jaxpr.eqns[0].primitive == qnode_prim + + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + for i in range(3): + assert qfunc_jaxpr.eqns[i].primitive == qml.PauliX._primitive + assert len(qfunc_jaxpr.eqns) == 4 + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res, jax.numpy.array([0, 1])) + + res2 = circuit(n_iterations=4) + assert qml.math.allclose(res2, jax.numpy.array([1, 0])) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_multiple_measurements(x64_mode): + """Test that the qnode can return multiple measurements.""" + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + @qml.qnode(qml.device("default.qubit", wires=3, shots=50)) + def circuit(): + return qml.sample(), qml.probs(wires=(0, 1)), qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)() + + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + + assert qfunc_jaxpr.eqns[0].primitive == qml.measurements.SampleMP._wires_primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + assert jaxpr.out_avals[0] == jax.core.ShapedArray( + (50, 3), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + assert jaxpr.out_avals[1] == jax.core.ShapedArray( + (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + assert jaxpr.out_avals[2] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + res1, res2, res3 = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res1, jax.numpy.zeros((50, 3))) + assert qml.math.allclose(res2, jax.numpy.array([1, 0, 0, 0])) + assert qml.math.allclose(res3, 1.0) + + res1, res2, res3 = circuit() + assert qml.math.allclose(res1, jax.numpy.zeros((50, 3))) + assert qml.math.allclose(res2, jax.numpy.array([1, 0, 0, 0])) + assert qml.math.allclose(res3, 1.0) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_complex_return_types(x64_mode): + """Test returning measurements with complex values.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + @qml.qnode(qml.device("default.qubit", wires=3)) + def circuit(): + return qml.state(), qml.density_matrix(wires=(0, 1)) + + jaxpr = jax.make_jaxpr(circuit)() + + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + + assert qfunc_jaxpr.eqns[0].primitive == qml.measurements.StateMP._wires_primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.measurements.DensityMatrixMP._wires_primitive + + assert jaxpr.out_avals[0] == jax.core.ShapedArray( + (8,), jax.numpy.complex128 if x64_mode else jax.numpy.complex64 + ) + assert jaxpr.out_avals[1] == jax.core.ShapedArray( + (4, 4), jax.numpy.complex128 if x64_mode else jax.numpy.complex64 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +def test_capture_qnode_kwargs(): + """Test that qnode kwargs are captured as parameters.""" + + dev = qml.device("default.qubit", wires=3) + + @qml.qnode( + dev, + diff_method="parameter-shift", + grad_on_execution=False, + cache=True, + cachesize=10, + max_diff=2, + ) + def circuit(): + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qnode_prim + expected = { + "diff_method": "parameter-shift", + "grad_on_execution": False, + "cache": True, + "cachesize": 10, + "max_diff": 2, + "max_expansion": 10, + "device_vjp": False, + } + assert jaxpr.eqns[0].params["qnode_kwargs"] == expected From c83d98c243756d7cba95368f154e16f8737c6593 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:46:13 +0000 Subject: [PATCH 2/3] Update stable dependency files (#5703) Automatic update of stable requirement files to snapshot valid python environments. Because bots are not able to trigger CI on their own, please do so by pushing an empty commit to this branch using the following command: ``` git commit --allow-empty -m 'trigger ci' ``` Alternatively, wait for this branch to be out-of-date with master, then just use the "Update branch" button! Note that it is expected that the PennyLane-Lightning repo is a version ahead of the release, because the version number is taken from the dev branch. Trying to `pip install` from the files will fail until that major version of Lightning is released. If pip install fails with a not found error when installing because of this, it can be fixed by manually downgrading the PennyLane-Lightning version number in the file by 1 version and trying again. --------- Co-authored-by: GitHub Actions Bot <> Co-authored-by: Christina Lee --- .github/stable/all_interfaces.txt | 36 ++++++++++++--------- .github/stable/core.txt | 29 ++++++++++------- .github/stable/doc.txt | 18 +++++------ .github/stable/external.txt | 54 ++++++++++++++++++------------- .github/stable/jax.txt | 31 ++++++++++-------- .github/stable/tf.txt | 32 ++++++++++-------- .github/stable/torch.txt | 31 ++++++++++-------- 7 files changed, 132 insertions(+), 99 deletions(-) diff --git a/.github/stable/all_interfaces.txt b/.github/stable/all_interfaces.txt index 77b45a0d419..05f6fb152b4 100644 --- a/.github/stable/all_interfaces.txt +++ b/.github/stable/all_interfaces.txt @@ -1,17 +1,18 @@ absl-py==2.1.0 appdirs==1.4.4 +astroid==2.6.6 astunparse==1.6.3 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -22,12 +23,12 @@ execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 fsspec==2024.5.0 future==1.0.0 gast==0.5.4 google-pasta==0.2.0 -grpcio==1.63.0 +grpcio==1.64.0 h5py==3.11.0 identify==2.5.36 idna==3.7 @@ -40,22 +41,24 @@ jaxlib==0.4.23 Jinja2==3.1.4 keras==3.3.3 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 libclang==18.1.1 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 +mccabe==0.6.1 mdurl==0.1.2 ml-dtypes==0.3.2 mpmath==1.3.0 mypy-extensions==1.0.0 namex==0.0.8 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 opt-einsum==3.3.0 optree==0.11.0 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -67,8 +70,9 @@ protobuf==4.25.3 py==1.11.0 py-cpuinfo==9.0.0 Pygments==2.18.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -78,14 +82,14 @@ python-dateutil==2.9.0.post0 pytorch-triton-rocm==2.3.0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rich==13.7.1 rustworkx==0.14.2 -scipy==1.11.4 -scs==3.2.4.post1 +scipy==1.12.0 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 -sympy==1.12 +sympy==1.12.1 tensorboard==2.16.2 tensorboard-data-server==0.7.2 tensorflow==2.16.1 @@ -95,9 +99,9 @@ tf_keras==2.16.0 toml==0.10.2 tomli==2.0.1 torch==2.3.0+rocm6.0 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 Werkzeug==3.0.3 -wrapt==1.16.0 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/core.txt b/.github/stable/core.txt index def68dbf5b7..2c9830a6d42 100644 --- a/.github/stable/core.txt +++ b/.github/stable/core.txt @@ -1,15 +1,16 @@ appdirs==1.4.4 +astroid==2.6.6 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -19,7 +20,7 @@ exceptiongroup==1.2.1 execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 -fonttools==4.51.0 +fonttools==4.53.0 future==1.0.0 identify==2.5.36 idna==3.7 @@ -27,12 +28,14 @@ importlib_resources==6.4.0 iniconfig==2.0.0 isort==5.13.2 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 matplotlib==3.9.0 +mccabe==0.6.1 mypy-extensions==1.0.0 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -42,8 +45,9 @@ pluggy==1.5.0 pre-commit==3.7.1 py==1.11.0 py-cpuinfo==9.0.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -53,15 +57,16 @@ pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rustworkx==0.14.2 scipy==1.11.4 -scs==3.2.4.post1 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/doc.txt b/.github/stable/doc.txt index a7492e17485..e4c039b43e6 100644 --- a/.github/stable/doc.txt +++ b/.github/stable/doc.txt @@ -10,7 +10,7 @@ autograd==1.6.2 autoray==0.6.12 Babel==2.15.0 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 charset-normalizer==3.3.2 cirq-core==1.3.0 contourpy==1.2.1 @@ -20,7 +20,7 @@ docutils==0.16 duet==0.2.9 exceptiongroup==1.2.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 frozenlist==1.4.1 fsspec==2024.5.0 future==1.0.0 @@ -29,7 +29,7 @@ google-auth==2.29.0 google-auth-oauthlib==0.4.6 google-pasta==0.2.0 graphviz==0.20.3 -grpcio==1.63.0 +grpcio==1.64.0 h5py==3.11.0 idna==3.7 imagesize==1.4.1 @@ -72,8 +72,8 @@ pybtex-docutils==1.0.3 Pygments==2.18.0 pygments-github-lexers==0.0.5 pyparsing==3.1.2 -pyscf==2.5.0 -pytest==8.2.0 +pyscf==2.6.0 +pytest==8.2.1 python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 @@ -81,7 +81,7 @@ requests==2.28.2 requests-oauthlib==2.0.0 rsa==4.9 rustworkx==0.12.1 -scipy==1.13.0 +scipy==1.13.1 semantic-version==2.10.0 six==1.16.0 snowballstemmer==2.2.0 @@ -98,7 +98,7 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sphinxext-opengraph==0.6.3 -sympy==1.12 +sympy==1.12.1 tensorboard==2.11.2 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 @@ -111,11 +111,11 @@ toml==0.10.2 tomli==2.0.1 torch==1.9.0+cpu tqdm==4.66.4 -typing_extensions==4.11.0 +typing_extensions==4.12.1 tzdata==2024.1 urllib3==1.26.18 Werkzeug==3.0.3 wrapt==1.16.0 xanadu-sphinx-theme==0.5.0 yarl==1.9.4 -zipp==3.18.2 +zipp==3.19.1 diff --git a/.github/stable/external.txt b/.github/stable/external.txt index febc739d8c6..89e5d1cf522 100644 --- a/.github/stable/external.txt +++ b/.github/stable/external.txt @@ -1,32 +1,35 @@ absl-py==2.1.0 -anyio==4.3.0 +anyio==4.4.0 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 +astroid==2.6.6 asttokens==2.4.1 astunparse==1.6.3 async-lru==2.0.4 attrs==23.2.0 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 Babel==2.15.0 beautifulsoup4==4.12.3 black==24.4.2 bleach==6.1.0 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cffi==1.16.0 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 comm==0.2.2 contourpy==1.2.1 -coverage==7.5.1 +cotengra==0.6.2 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 +cytoolz==0.12.3 debugpy==1.8.1 decorator==5.1.1 defusedxml==0.7.1 @@ -40,12 +43,12 @@ fastjsonschema==2.19.1 filelock==3.14.0 flaky==3.8.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 fqdn==1.5.1 future==1.0.0 gast==0.5.4 google-pasta==0.2.0 -grpcio==1.63.0 +grpcio==1.64.0 h11==0.14.0 h5py==3.11.0 httpcore==1.0.5 @@ -71,23 +74,26 @@ jsonschema==4.22.0 jsonschema-specifications==2023.12.1 jupyter-events==0.10.0 jupyter-lsp==2.2.5 -jupyter_client==8.6.1 +jupyter_client==8.6.2 jupyter_core==5.7.2 -jupyter_server==2.14.0 +jupyter_server==2.14.1 jupyter_server_terminals==0.5.3 -jupyterlab==4.1.8 +jupyterlab==4.2.1 jupyterlab-widgets==1.1.7 jupyterlab_pygments==0.3.0 -jupyterlab_server==2.27.1 +jupyterlab_server==2.27.2 keras==3.3.3 kiwisolver==1.4.5 lark==1.1.9 +lazy-object-proxy==1.10.0 libclang==18.1.1 +llvmlite==0.42.0 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 matplotlib-inline==0.1.7 +mccabe==0.6.1 mdurl==0.1.2 mistune==3.0.2 ml-dtypes==0.3.2 @@ -98,13 +104,14 @@ nbconvert==7.16.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.2.1 -nodeenv==1.8.0 -notebook==7.1.3 +nodeenv==1.9.0 +notebook==7.2.0 notebook_shim==0.2.4 +numba==0.59.1 numpy==1.26.4 opt-einsum==3.3.0 optree==0.11.0 -osqp==0.6.5 +osqp==0.6.7 overrides==7.7.0 packaging==24.0 pandocfilters==1.5.1 @@ -118,7 +125,7 @@ platformdirs==4.2.2 pluggy==1.5.0 pre-commit==3.7.1 prometheus_client==0.20.0 -prompt-toolkit==3.0.43 +prompt_toolkit==3.0.45 protobuf==4.25.3 psutil==5.9.8 ptyprocess==0.7.0 @@ -127,9 +134,10 @@ py==1.11.0 py-cpuinfo==9.0.0 pycparser==2.22 Pygments==2.18.0 +pylint==2.7.4 pyparsing==3.1.2 pyperclip==1.8.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -141,15 +149,16 @@ PyYAML==6.0.1 pyzmq==26.0.3 pyzx==0.8.0 qdldl==0.1.7.post2 +quimb==1.8.1 referencing==0.35.1 -requests==2.31.0 +requests==2.32.3 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.1 rpds-py==0.18.1 rustworkx==0.14.2 -scipy==1.11.4 -scs==3.2.4.post1 +scipy==1.12.0 +scs==3.2.4.post2 semantic-version==2.10.0 Send2Trash==1.8.3 six==1.16.0 @@ -168,11 +177,12 @@ tinycss2==1.3.0 toml==0.10.2 tomli==2.0.1 tomlkit==0.12.5 +toolz==0.12.1 tornado==6.4 tqdm==4.66.4 traitlets==5.14.3 types-python-dateutil==2.9.0.20240316 -typing_extensions==4.11.0 +typing_extensions==4.12.1 uri-template==1.3.0 urllib3==2.2.1 virtualenv==20.26.2 @@ -182,5 +192,5 @@ webencodings==0.5.1 websocket-client==1.8.0 Werkzeug==3.0.3 widgetsnbextension==3.6.6 -wrapt==1.16.0 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/jax.txt b/.github/stable/jax.txt index d3181d8b927..4978dada90c 100644 --- a/.github/stable/jax.txt +++ b/.github/stable/jax.txt @@ -1,15 +1,16 @@ appdirs==1.4.4 +astroid==2.6.6 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -19,7 +20,7 @@ exceptiongroup==1.2.1 execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 -fonttools==4.51.0 +fonttools==4.53.0 future==1.0.0 identify==2.5.36 idna==3.7 @@ -30,14 +31,16 @@ isort==5.13.2 jax==0.4.23 jaxlib==0.4.23 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 matplotlib==3.9.0 +mccabe==0.6.1 ml-dtypes==0.4.0 mypy-extensions==1.0.0 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 opt-einsum==3.3.0 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -47,8 +50,9 @@ pluggy==1.5.0 pre-commit==3.7.1 py==1.11.0 py-cpuinfo==9.0.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -58,15 +62,16 @@ pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rustworkx==0.14.2 -scipy==1.11.4 -scs==3.2.4.post1 +scipy==1.12.0 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/tf.txt b/.github/stable/tf.txt index 41cdf783405..1bf9daf1825 100644 --- a/.github/stable/tf.txt +++ b/.github/stable/tf.txt @@ -1,17 +1,18 @@ absl-py==2.1.0 appdirs==1.4.4 +astroid==2.6.6 astunparse==1.6.3 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -22,11 +23,11 @@ execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 future==1.0.0 gast==0.5.4 google-pasta==0.2.0 -grpcio==1.63.0 +grpcio==1.64.0 h5py==3.11.0 identify==2.5.36 idna==3.7 @@ -36,21 +37,23 @@ iniconfig==2.0.0 isort==5.13.2 keras==3.3.3 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 libclang==18.1.1 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 +mccabe==0.6.1 mdurl==0.1.2 ml-dtypes==0.3.2 mypy-extensions==1.0.0 namex==0.0.8 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 opt-einsum==3.3.0 optree==0.11.0 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -62,8 +65,9 @@ protobuf==4.25.3 py==1.11.0 py-cpuinfo==9.0.0 Pygments==2.18.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -73,11 +77,11 @@ pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rich==13.7.1 rustworkx==0.14.2 scipy==1.11.4 -scs==3.2.4.post1 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 tensorboard==2.16.2 @@ -88,9 +92,9 @@ termcolor==2.4.0 tf_keras==2.16.0 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 Werkzeug==3.0.3 -wrapt==1.16.0 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/torch.txt b/.github/stable/torch.txt index 1acf7336c96..18a7db0f881 100644 --- a/.github/stable/torch.txt +++ b/.github/stable/torch.txt @@ -1,15 +1,16 @@ appdirs==1.4.4 +astroid==2.6.6 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -19,7 +20,7 @@ exceptiongroup==1.2.1 execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 -fonttools==4.51.0 +fonttools==4.53.0 fsspec==2024.5.0 future==1.0.0 identify==2.5.36 @@ -29,14 +30,16 @@ iniconfig==2.0.0 isort==5.13.2 Jinja2==3.1.4 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 MarkupSafe==2.1.5 matplotlib==3.9.0 +mccabe==0.6.1 mpmath==1.3.0 mypy-extensions==1.0.0 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -46,8 +49,9 @@ pluggy==1.5.0 pre-commit==3.7.1 py==1.11.0 py-cpuinfo==9.0.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -57,17 +61,18 @@ python-dateutil==2.9.0.post0 pytorch-triton-rocm==2.3.0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rustworkx==0.14.2 scipy==1.11.4 -scs==3.2.4.post1 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 -sympy==1.12 +sympy==1.12.1 toml==0.10.2 tomli==2.0.1 torch==2.3.0+rocm6.0 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 From 456e47447e72fcce15c04496a0656d64ebf9d0d3 Mon Sep 17 00:00:00 2001 From: Gabriel Bottrill <78718539+Gabriel-Bottrill@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:52:54 -0700 Subject: [PATCH 3/3] Added |2> -> |1> amplitude damping to amplitude damping. (#5757) **Context:** I recently added `qml.QutritAmplitudeDamping` based on specifications from this [paper](https://doi.org/10.48550/arXiv.1905.10481). This paper only considers the transition from |2> to |0>, in reality on a super conducting device it is more likely to have a |2> to |1> transition. This PR adds a variable and a new Kraus matrix to `qml.QutritAmplitudeDamping` to allow for the simulation of the |2> to |1> transition. **Description of the Change:** Adds a new variable, \gamma_3, to `qml.QutritAmplitudeDamping` that defines the new Kraus matrix K_3. This new Kraus operator allows for the simulation of |2> to |1> relaxations. **Benefits:** Allows for simulation of |2> to |1> relaxation. This will make for more accurate simulations and is useful for measurement error for qutrits. **Possible Drawbacks:** This changes `qml.QutritAmplitudeDamping` and adds an extra parameter, the naming may be a bit more confusing. **Related GitHub Issues:** N/A --------- Co-authored-by: Olivia Di Matteo <2068515+glassnotes@users.noreply.github.com> --- doc/releases/changelog-dev.md | 7 +- pennylane/__init__.py | 3 + pennylane/ops/qutrit/channel.py | 61 +++++++++---- .../test_qutrit_mixed_preprocessing.py | 1 + tests/ops/qutrit/test_qutrit_channel_ops.py | 89 ++++++++++++------- 5 files changed, 109 insertions(+), 52 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 9826b0d28c9..6d198b3b145 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -152,6 +152,10 @@ * ``qml.QutritDepolarizingChannel`` has been added, allowing for depolarizing noise to be simulated on the `default.qutrit.mixed` device. [(#5502)](https://github.com/PennyLaneAI/pennylane/pull/5502) +* `qml.QutritAmplitudeDamping` channel has been added, allowing for noise processes modelled by amplitude damping to be simulated on the `default.qutrit.mixed` device. + [(#5503)](https://github.com/PennyLaneAI/pennylane/pull/5503) + [(#5757)](https://github.com/PennyLaneAI/pennylane/pull/5757) +

Breaking changes 💔

* A custom decomposition can no longer be provided to `QDrift`. Instead, apply the operations in your custom @@ -177,9 +181,6 @@ * `Controlled.wires` does not include `self.work_wires` anymore. That can be accessed separately through `Controlled.work_wires`. Consequently, `Controlled.active_wires` has been removed in favour of the more common `Controlled.wires`. [(#5728)](https://github.com/PennyLaneAI/pennylane/pull/5728) - -* `qml.QutritAmplitudeDamping` channel has been added, allowing for noise processes modelled by amplitude damping to be simulated on the `default.qutrit.mixed` device. - [(#5503)](https://github.com/PennyLaneAI/pennylane/pull/5503)

Deprecations 👋

diff --git a/pennylane/__init__.py b/pennylane/__init__.py index 80419832ae0..139ebc29222 100644 --- a/pennylane/__init__.py +++ b/pennylane/__init__.py @@ -218,6 +218,9 @@ def device(name, *args, **kwargs): * :mod:`'default.qutrit' `: a simple state simulator of qutrit-based quantum circuit architectures. + * :mod:`'default.qutrit.mixed' `: a + mixed-state simulator of qutrit-based quantum circuit architectures. + * :mod:`'default.gaussian' `: a simple simulator of Gaussian states and operations on continuous-variable circuit architectures. diff --git a/pennylane/ops/qutrit/channel.py b/pennylane/ops/qutrit/channel.py index c50572cf206..82f5d9c4407 100644 --- a/pennylane/ops/qutrit/channel.py +++ b/pennylane/ops/qutrit/channel.py @@ -235,8 +235,10 @@ class QutritAmplitudeDamping(Channel): K_0 = \begin{bmatrix} 1 & 0 & 0\\ 0 & \sqrt{1-\gamma_1} & 0 \\ - 0 & 0 & \sqrt{1-\gamma_2} - \end{bmatrix}, \quad + 0 & 0 & \sqrt{1-(\gamma_2+\gamma_3)} + \end{bmatrix} + + .. math:: K_1 = \begin{bmatrix} 0 & \sqrt{\gamma_1} & 0 \\ 0 & 0 & 0 \\ @@ -246,70 +248,95 @@ class QutritAmplitudeDamping(Channel): 0 & 0 & \sqrt{\gamma_2} \\ 0 & 0 & 0 \\ 0 & 0 & 0 + \end{bmatrix}, \quad + K_3 = \begin{bmatrix} + 0 & 0 & 0 \\ + 0 & 0 & \sqrt{\gamma_3} \\ + 0 & 0 & 0 \end{bmatrix} - where :math:`\gamma_1 \in [0, 1]` and :math:`\gamma_2 \in [0, 1]` are the amplitude damping - probabilities for subspaces (0,1) and (0,2) respectively. + where :math:`\gamma_1, \gamma_2, \gamma_3 \in [0, 1]` are the amplitude damping + probabilities for subspaces (0,1), (0,2), and (1,2) respectively. .. note:: - The Kraus operators :math:`\{K_0, K_1, K_2\}` are adapted from [`1 `_] (Eq. 8). + When :math:`\gamma_3=0` then Kraus operators :math:`\{K_0, K_1, K_2\}` are adapted from + [`1 `_] (Eq. 8). + + The Kraus operator :math:`K_3` represents the :math:`|2 \rangle \rightarrow |1 \rangle` transition which is more + likely on some devices [`2 `_] (Sec II.A). + + To maintain normalization :math:`\gamma_2 + \gamma_3 \leq 1`. + **Details:** * Number of wires: 1 - * Number of parameters: 2 + * Number of parameters: 3 Args: gamma_1 (float): :math:`|1 \rangle \rightarrow |0 \rangle` amplitude damping probability. gamma_2 (float): :math:`|2 \rangle \rightarrow |0 \rangle` amplitude damping probability. + gamma_3 (float): :math:`|2 \rangle \rightarrow |1 \rangle` amplitude damping probability. wires (Sequence[int] or int): the wire the channel acts on id (str or None): String representing the operation (optional) """ - num_params = 2 + num_params = 3 num_wires = 1 grad_method = "F" - def __init__(self, gamma_1, gamma_2, wires, id=None): - # Verify gamma_1 and gamma_2 - for gamma in (gamma_1, gamma_2): - if not (math.is_abstract(gamma_1) or math.is_abstract(gamma_2)): + def __init__(self, gamma_1, gamma_2, gamma_3, wires, id=None): + # Verify input + for gamma in (gamma_1, gamma_2, gamma_3): + if not math.is_abstract(gamma): if not 0.0 <= gamma <= 1.0: raise ValueError("Each probability must be in the interval [0,1]") - super().__init__(gamma_1, gamma_2, wires=wires, id=id) + if not (math.is_abstract(gamma_2) or math.is_abstract(gamma_3)): + if not 0.0 <= gamma_2 + gamma_3 <= 1.0: + raise ValueError(r"\gamma_2+\gamma_3 must be in the interval [0,1]") + super().__init__(gamma_1, gamma_2, gamma_3, wires=wires, id=id) @staticmethod - def compute_kraus_matrices(gamma_1, gamma_2): # pylint:disable=arguments-differ + def compute_kraus_matrices(gamma_1, gamma_2, gamma_3): # pylint:disable=arguments-differ r"""Kraus matrices representing the ``QutritAmplitudeDamping`` channel. Args: gamma_1 (float): :math:`|1\rangle \rightarrow |0\rangle` amplitude damping probability. gamma_2 (float): :math:`|2\rangle \rightarrow |0\rangle` amplitude damping probability. + gamma_3 (float): :math:`|2\rangle \rightarrow |1\rangle` amplitude damping probability. Returns: list(array): list of Kraus matrices **Example** - >>> qml.QutritAmplitudeDamping.compute_kraus_matrices(0.5, 0.25) + >>> qml.QutritAmplitudeDamping.compute_kraus_matrices(0.5, 0.25, 0.36) [ array([ [1. , 0. , 0. ], [0. , 0.70710678, 0. ], - [0. , 0. , 0.8660254 ]]), + [0. , 0. , 0.6244998 ]]), array([ [0. , 0.70710678, 0. ], [0. , 0. , 0. ], [0. , 0. , 0. ]]), array([ [0. , 0. , 0.5 ], [0. , 0. , 0. ], [0. , 0. , 0. ]]) + array([ [0. , 0. , 0. ], + [0. , 0. , 0.6 ], + [0. , 0. , 0. ]]) ] """ - K0 = math.diag([1, math.sqrt(1 - gamma_1 + math.eps), math.sqrt(1 - gamma_2 + math.eps)]) + K0 = math.diag( + [1, math.sqrt(1 - gamma_1 + math.eps), math.sqrt(1 - gamma_2 - gamma_3 + math.eps)] + ) K1 = math.sqrt(gamma_1 + math.eps) * math.convert_like( math.cast_like(math.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]), gamma_1), gamma_1 ) K2 = math.sqrt(gamma_2 + math.eps) * math.convert_like( math.cast_like(math.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]), gamma_2), gamma_2 ) - return [K0, K1, K2] + K3 = math.sqrt(gamma_3 + math.eps) * math.convert_like( + math.cast_like(math.array([[0, 0, 0], [0, 0, 1], [0, 0, 0]]), gamma_3), gamma_3 + ) + return [K0, K1, K2, K3] diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py index 9574fdbec25..0d80adf2943 100644 --- a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py +++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py @@ -126,6 +126,7 @@ def test_measurement_is_swapped_out(self, mp_fn, mp_cls, shots): (qml.Snapshot(), True), (qml.TRX(1.1, 0), True), (qml.QutritDepolarizingChannel(0.4, 0), True), + (qml.QutritAmplitudeDamping(0.1, 0.2, 0.12, 0), True), ], ) def test_accepted_operator(self, op, expected): diff --git a/tests/ops/qutrit/test_qutrit_channel_ops.py b/tests/ops/qutrit/test_qutrit_channel_ops.py index 585f623f499..4b2612815ce 100644 --- a/tests/ops/qutrit/test_qutrit_channel_ops.py +++ b/tests/ops/qutrit/test_qutrit_channel_ops.py @@ -176,15 +176,15 @@ class TestQutritAmplitudeDamping: def test_gamma_zero(self, tol): """Test gamma_1=gamma_2=0 gives correct Kraus matrices""" - kraus_mats = qml.QutritAmplitudeDamping(0, 0, wires=0).kraus_matrices() + kraus_mats = qml.QutritAmplitudeDamping(0, 0, 0, wires=0).kraus_matrices() assert np.allclose(kraus_mats[0], np.eye(3), atol=tol, rtol=0) - assert np.allclose(kraus_mats[1], np.zeros((3, 3)), atol=tol, rtol=0) - assert np.allclose(kraus_mats[2], np.zeros((3, 3)), atol=tol, rtol=0) + for kraus_mat in kraus_mats[1:]: + assert np.allclose(kraus_mat, np.zeros((3, 3)), atol=tol, rtol=0) - @pytest.mark.parametrize("gamma1,gamma2", ((0.1, 0.2), (0.75, 0.75))) - def test_gamma_arbitrary(self, gamma1, gamma2, tol): - """Test the correct Kraus matrices are returned, also ensures that the sum of gammas can be over 1.""" - K_0 = np.diag((1, np.sqrt(1 - gamma1), np.sqrt(1 - gamma2))) + @pytest.mark.parametrize("gamma1,gamma2,gamma3", ((0.1, 0.2, 0.3), (0.75, 0.75, 0.25))) + def test_gamma_arbitrary(self, gamma1, gamma2, gamma3, tol): + """Test the correct Kraus matrices are returned.""" + K_0 = np.diag((1, np.sqrt(1 - gamma1), np.sqrt(1 - gamma2 - gamma3))) K_1 = np.zeros((3, 3)) K_1[0, 1] = np.sqrt(gamma1) @@ -192,33 +192,48 @@ def test_gamma_arbitrary(self, gamma1, gamma2, tol): K_2 = np.zeros((3, 3)) K_2[0, 2] = np.sqrt(gamma2) - expected = [K_0, K_1, K_2] - damping_channel = qml.QutritAmplitudeDamping(gamma1, gamma2, wires=0) + K_3 = np.zeros((3, 3)) + K_3[1, 2] = np.sqrt(gamma3) + + expected = [K_0, K_1, K_2, K_3] + damping_channel = qml.QutritAmplitudeDamping(gamma1, gamma2, gamma3, wires=0) assert np.allclose(damping_channel.kraus_matrices(), expected, atol=tol, rtol=0) - @pytest.mark.parametrize("gamma1,gamma2", ((1.5, 0.0), (0.0, 1.0 + math.eps))) - def test_gamma_invalid_parameter(self, gamma1, gamma2): - """Ensures that error is thrown when gamma_1 or gamma_2 are outside [0,1]""" - with pytest.raises(ValueError, match="Each probability must be in the interval"): - channel.QutritAmplitudeDamping(gamma1, gamma2, wires=0).kraus_matrices() + @pytest.mark.parametrize( + "gamma1,gamma2,gamma3", + ( + (1.5, 0.0, 0.0), + (0.0, 1.0 + math.eps, 0.0), + (0.0, 0.0, 1.1), + (0.0, 0.33, 0.67 + math.eps), + ), + ) + def test_gamma_invalid_parameter(self, gamma1, gamma2, gamma3): + """Ensures that error is thrown when gamma_1, gamma_2, gamma_3, or (gamma_2 + gamma_3) are outside [0,1]""" + with pytest.raises(ValueError, match="must be in the interval"): + channel.QutritAmplitudeDamping(gamma1, gamma2, gamma3, wires=0).kraus_matrices() @staticmethod - def expected_jac_fn(gamma_1, gamma_2): + def expected_jac_fn(gamma_1, gamma_2, gamma_3): """Gets the expected Jacobian of Kraus matrices""" - partial_1 = [math.zeros((3, 3)) for _ in range(3)] + partial_1 = [math.zeros((3, 3)) for _ in range(4)] partial_1[0][1, 1] = -1 / (2 * math.sqrt(1 - gamma_1)) partial_1[1][0, 1] = 1 / (2 * math.sqrt(gamma_1)) - partial_2 = [math.zeros((3, 3)) for _ in range(3)] - partial_2[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2)) + partial_2 = [math.zeros((3, 3)) for _ in range(4)] + partial_2[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2 - gamma_3)) partial_2[2][0, 2] = 1 / (2 * math.sqrt(gamma_2)) - return [partial_1, partial_2] + partial_3 = [math.zeros((3, 3)) for _ in range(4)] + partial_3[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2 - gamma_3)) + partial_3[3][1, 2] = 1 / (2 * math.sqrt(gamma_3)) + + return [partial_1, partial_2, partial_3] @staticmethod - def kraus_fn(gamma_1, gamma_2): + def kraus_fn(gamma_1, gamma_2, gamma_3): """Gets the Kraus matrices of QutritAmplitudeDamping channel, used for differentiation.""" - damping_channel = qml.QutritAmplitudeDamping(gamma_1, gamma_2, wires=0) + damping_channel = qml.QutritAmplitudeDamping(gamma_1, gamma_2, gamma_3, wires=0) return math.stack(damping_channel.kraus_matrices()) @pytest.mark.autograd @@ -226,8 +241,10 @@ def test_kraus_jac_autograd(self): """Tests Jacobian of Kraus matrices using autograd.""" gamma_1 = pnp.array(0.43, requires_grad=True) gamma_2 = pnp.array(0.12, requires_grad=True) - jac = qml.jacobian(self.kraus_fn)(gamma_1, gamma_2) - assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2)) + gamma_3 = pnp.array(0.35, requires_grad=True) + + jac = qml.jacobian(self.kraus_fn)(gamma_1, gamma_2, gamma_3) + assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3)) @pytest.mark.torch def test_kraus_jac_torch(self): @@ -236,11 +253,15 @@ def test_kraus_jac_torch(self): gamma_1 = torch.tensor(0.43, requires_grad=True) gamma_2 = torch.tensor(0.12, requires_grad=True) + gamma_3 = torch.tensor(0.35, requires_grad=True) - jac = torch.autograd.functional.jacobian(self.kraus_fn, (gamma_1, gamma_2)) - expected = self.expected_jac_fn(gamma_1.detach().numpy(), gamma_2.detach().numpy()) - assert math.allclose(jac[0].detach().numpy(), expected[0]) - assert math.allclose(jac[1].detach().numpy(), expected[1]) + jac = torch.autograd.functional.jacobian(self.kraus_fn, (gamma_1, gamma_2, gamma_3)) + expected = self.expected_jac_fn( + gamma_1.detach().numpy(), gamma_2.detach().numpy(), gamma_3.detach().numpy() + ) + + for res_partial, exp_partial in zip(jac, expected): + assert math.allclose(res_partial.detach().numpy(), exp_partial) @pytest.mark.tf def test_kraus_jac_tf(self): @@ -249,10 +270,12 @@ def test_kraus_jac_tf(self): gamma_1 = tf.Variable(0.43) gamma_2 = tf.Variable(0.12) + gamma_3 = tf.Variable(0.35) + with tf.GradientTape() as tape: - out = self.kraus_fn(gamma_1, gamma_2) - jac = tape.jacobian(out, (gamma_1, gamma_2)) - assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2)) + out = self.kraus_fn(gamma_1, gamma_2, gamma_3) + jac = tape.jacobian(out, (gamma_1, gamma_2, gamma_3)) + assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3)) @pytest.mark.jax def test_kraus_jac_jax(self): @@ -261,5 +284,7 @@ def test_kraus_jac_jax(self): gamma_1 = jax.numpy.array(0.43) gamma_2 = jax.numpy.array(0.12) - jac = jax.jacobian(self.kraus_fn, argnums=[0, 1])(gamma_1, gamma_2) - assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2)) + gamma_3 = jax.numpy.array(0.35) + + jac = jax.jacobian(self.kraus_fn, argnums=[0, 1, 2])(gamma_1, gamma_2, gamma_3) + assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3))