diff --git a/.github/stable/all_interfaces.txt b/.github/stable/all_interfaces.txt
index 77b45a0d419..05f6fb152b4 100644
--- a/.github/stable/all_interfaces.txt
+++ b/.github/stable/all_interfaces.txt
@@ -1,17 +1,18 @@
absl-py==2.1.0
appdirs==1.4.4
+astroid==2.6.6
astunparse==1.6.3
autograd==1.6.2
-autoray==0.6.9
+autoray==0.6.12
black==24.4.2
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
cfgv==3.4.0
charset-normalizer==3.3.2
-clarabel==0.7.1
+clarabel==0.9.0
click==8.1.7
contourpy==1.2.1
-coverage==7.5.1
+coverage==7.5.3
cvxopt==1.3.2
cvxpy==1.5.1
cycler==0.12.1
@@ -22,12 +23,12 @@ execnet==2.1.1
filelock==3.14.0
flaky==3.8.1
flatbuffers==24.3.25
-fonttools==4.51.0
+fonttools==4.53.0
fsspec==2024.5.0
future==1.0.0
gast==0.5.4
google-pasta==0.2.0
-grpcio==1.63.0
+grpcio==1.64.0
h5py==3.11.0
identify==2.5.36
idna==3.7
@@ -40,22 +41,24 @@ jaxlib==0.4.23
Jinja2==3.1.4
keras==3.3.3
kiwisolver==1.4.5
+lazy-object-proxy==1.10.0
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
+mccabe==0.6.1
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
mypy-extensions==1.0.0
namex==0.0.8
networkx==3.2.1
-nodeenv==1.8.0
+nodeenv==1.9.0
numpy==1.26.4
opt-einsum==3.3.0
optree==0.11.0
-osqp==0.6.5
+osqp==0.6.7
packaging==24.0
pathspec==0.12.1
PennyLane_Lightning==0.37.0
@@ -67,8 +70,9 @@ protobuf==4.25.3
py==1.11.0
py-cpuinfo==9.0.0
Pygments==2.18.0
+pylint==2.7.4
pyparsing==3.1.2
-pytest==8.2.0
+pytest==8.2.1
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-forked==1.6.0
@@ -78,14 +82,14 @@ python-dateutil==2.9.0.post0
pytorch-triton-rocm==2.3.0
PyYAML==6.0.1
qdldl==0.1.7.post2
-requests==2.31.0
+requests==2.32.3
rich==13.7.1
rustworkx==0.14.2
-scipy==1.11.4
-scs==3.2.4.post1
+scipy==1.12.0
+scs==3.2.4.post2
semantic-version==2.10.0
six==1.16.0
-sympy==1.12
+sympy==1.12.1
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
@@ -95,9 +99,9 @@ tf_keras==2.16.0
toml==0.10.2
tomli==2.0.1
torch==2.3.0+rocm6.0
-typing_extensions==4.11.0
+typing_extensions==4.12.1
urllib3==2.2.1
virtualenv==20.26.2
Werkzeug==3.0.3
-wrapt==1.16.0
-zipp==3.18.2
+wrapt==1.12.1
+zipp==3.19.1
diff --git a/.github/stable/core.txt b/.github/stable/core.txt
index def68dbf5b7..2c9830a6d42 100644
--- a/.github/stable/core.txt
+++ b/.github/stable/core.txt
@@ -1,15 +1,16 @@
appdirs==1.4.4
+astroid==2.6.6
autograd==1.6.2
-autoray==0.6.9
+autoray==0.6.12
black==24.4.2
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
cfgv==3.4.0
charset-normalizer==3.3.2
-clarabel==0.7.1
+clarabel==0.9.0
click==8.1.7
contourpy==1.2.1
-coverage==7.5.1
+coverage==7.5.3
cvxopt==1.3.2
cvxpy==1.5.1
cycler==0.12.1
@@ -19,7 +20,7 @@ exceptiongroup==1.2.1
execnet==2.1.1
filelock==3.14.0
flaky==3.8.1
-fonttools==4.51.0
+fonttools==4.53.0
future==1.0.0
identify==2.5.36
idna==3.7
@@ -27,12 +28,14 @@ importlib_resources==6.4.0
iniconfig==2.0.0
isort==5.13.2
kiwisolver==1.4.5
+lazy-object-proxy==1.10.0
matplotlib==3.9.0
+mccabe==0.6.1
mypy-extensions==1.0.0
networkx==3.2.1
-nodeenv==1.8.0
+nodeenv==1.9.0
numpy==1.26.4
-osqp==0.6.5
+osqp==0.6.7
packaging==24.0
pathspec==0.12.1
PennyLane_Lightning==0.37.0
@@ -42,8 +45,9 @@ pluggy==1.5.0
pre-commit==3.7.1
py==1.11.0
py-cpuinfo==9.0.0
+pylint==2.7.4
pyparsing==3.1.2
-pytest==8.2.0
+pytest==8.2.1
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-forked==1.6.0
@@ -53,15 +57,16 @@ pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
PyYAML==6.0.1
qdldl==0.1.7.post2
-requests==2.31.0
+requests==2.32.3
rustworkx==0.14.2
scipy==1.11.4
-scs==3.2.4.post1
+scs==3.2.4.post2
semantic-version==2.10.0
six==1.16.0
toml==0.10.2
tomli==2.0.1
-typing_extensions==4.11.0
+typing_extensions==4.12.1
urllib3==2.2.1
virtualenv==20.26.2
-zipp==3.18.2
+wrapt==1.12.1
+zipp==3.19.1
diff --git a/.github/stable/doc.txt b/.github/stable/doc.txt
index a7492e17485..e4c039b43e6 100644
--- a/.github/stable/doc.txt
+++ b/.github/stable/doc.txt
@@ -10,7 +10,7 @@ autograd==1.6.2
autoray==0.6.12
Babel==2.15.0
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
charset-normalizer==3.3.2
cirq-core==1.3.0
contourpy==1.2.1
@@ -20,7 +20,7 @@ docutils==0.16
duet==0.2.9
exceptiongroup==1.2.1
flatbuffers==24.3.25
-fonttools==4.51.0
+fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.5.0
future==1.0.0
@@ -29,7 +29,7 @@ google-auth==2.29.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
graphviz==0.20.3
-grpcio==1.63.0
+grpcio==1.64.0
h5py==3.11.0
idna==3.7
imagesize==1.4.1
@@ -72,8 +72,8 @@ pybtex-docutils==1.0.3
Pygments==2.18.0
pygments-github-lexers==0.0.5
pyparsing==3.1.2
-pyscf==2.5.0
-pytest==8.2.0
+pyscf==2.6.0
+pytest==8.2.1
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
@@ -81,7 +81,7 @@ requests==2.28.2
requests-oauthlib==2.0.0
rsa==4.9
rustworkx==0.12.1
-scipy==1.13.0
+scipy==1.13.1
semantic-version==2.10.0
six==1.16.0
snowballstemmer==2.2.0
@@ -98,7 +98,7 @@ sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sphinxext-opengraph==0.6.3
-sympy==1.12
+sympy==1.12.1
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
@@ -111,11 +111,11 @@ toml==0.10.2
tomli==2.0.1
torch==1.9.0+cpu
tqdm==4.66.4
-typing_extensions==4.11.0
+typing_extensions==4.12.1
tzdata==2024.1
urllib3==1.26.18
Werkzeug==3.0.3
wrapt==1.16.0
xanadu-sphinx-theme==0.5.0
yarl==1.9.4
-zipp==3.18.2
+zipp==3.19.1
diff --git a/.github/stable/external.txt b/.github/stable/external.txt
index febc739d8c6..89e5d1cf522 100644
--- a/.github/stable/external.txt
+++ b/.github/stable/external.txt
@@ -1,32 +1,35 @@
absl-py==2.1.0
-anyio==4.3.0
+anyio==4.4.0
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
+astroid==2.6.6
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
attrs==23.2.0
autograd==1.6.2
-autoray==0.6.9
+autoray==0.6.12
Babel==2.15.0
beautifulsoup4==4.12.3
black==24.4.2
bleach==6.1.0
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
cffi==1.16.0
cfgv==3.4.0
charset-normalizer==3.3.2
-clarabel==0.7.1
+clarabel==0.9.0
click==8.1.7
comm==0.2.2
contourpy==1.2.1
-coverage==7.5.1
+cotengra==0.6.2
+coverage==7.5.3
cvxopt==1.3.2
cvxpy==1.5.1
cycler==0.12.1
+cytoolz==0.12.3
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
@@ -40,12 +43,12 @@ fastjsonschema==2.19.1
filelock==3.14.0
flaky==3.8.1
flatbuffers==24.3.25
-fonttools==4.51.0
+fonttools==4.53.0
fqdn==1.5.1
future==1.0.0
gast==0.5.4
google-pasta==0.2.0
-grpcio==1.63.0
+grpcio==1.64.0
h11==0.14.0
h5py==3.11.0
httpcore==1.0.5
@@ -71,23 +74,26 @@ jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter-events==0.10.0
jupyter-lsp==2.2.5
-jupyter_client==8.6.1
+jupyter_client==8.6.2
jupyter_core==5.7.2
-jupyter_server==2.14.0
+jupyter_server==2.14.1
jupyter_server_terminals==0.5.3
-jupyterlab==4.1.8
+jupyterlab==4.2.1
jupyterlab-widgets==1.1.7
jupyterlab_pygments==0.3.0
-jupyterlab_server==2.27.1
+jupyterlab_server==2.27.2
keras==3.3.3
kiwisolver==1.4.5
lark==1.1.9
+lazy-object-proxy==1.10.0
libclang==18.1.1
+llvmlite==0.42.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
matplotlib-inline==0.1.7
+mccabe==0.6.1
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.3.2
@@ -98,13 +104,14 @@ nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
-nodeenv==1.8.0
-notebook==7.1.3
+nodeenv==1.9.0
+notebook==7.2.0
notebook_shim==0.2.4
+numba==0.59.1
numpy==1.26.4
opt-einsum==3.3.0
optree==0.11.0
-osqp==0.6.5
+osqp==0.6.7
overrides==7.7.0
packaging==24.0
pandocfilters==1.5.1
@@ -118,7 +125,7 @@ platformdirs==4.2.2
pluggy==1.5.0
pre-commit==3.7.1
prometheus_client==0.20.0
-prompt-toolkit==3.0.43
+prompt_toolkit==3.0.45
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
@@ -127,9 +134,10 @@ py==1.11.0
py-cpuinfo==9.0.0
pycparser==2.22
Pygments==2.18.0
+pylint==2.7.4
pyparsing==3.1.2
pyperclip==1.8.2
-pytest==8.2.0
+pytest==8.2.1
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-forked==1.6.0
@@ -141,15 +149,16 @@ PyYAML==6.0.1
pyzmq==26.0.3
pyzx==0.8.0
qdldl==0.1.7.post2
+quimb==1.8.1
referencing==0.35.1
-requests==2.31.0
+requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
rustworkx==0.14.2
-scipy==1.11.4
-scs==3.2.4.post1
+scipy==1.12.0
+scs==3.2.4.post2
semantic-version==2.10.0
Send2Trash==1.8.3
six==1.16.0
@@ -168,11 +177,12 @@ tinycss2==1.3.0
toml==0.10.2
tomli==2.0.1
tomlkit==0.12.5
+toolz==0.12.1
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
types-python-dateutil==2.9.0.20240316
-typing_extensions==4.11.0
+typing_extensions==4.12.1
uri-template==1.3.0
urllib3==2.2.1
virtualenv==20.26.2
@@ -182,5 +192,5 @@ webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.3
widgetsnbextension==3.6.6
-wrapt==1.16.0
-zipp==3.18.2
+wrapt==1.12.1
+zipp==3.19.1
diff --git a/.github/stable/jax.txt b/.github/stable/jax.txt
index d3181d8b927..4978dada90c 100644
--- a/.github/stable/jax.txt
+++ b/.github/stable/jax.txt
@@ -1,15 +1,16 @@
appdirs==1.4.4
+astroid==2.6.6
autograd==1.6.2
-autoray==0.6.9
+autoray==0.6.12
black==24.4.2
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
cfgv==3.4.0
charset-normalizer==3.3.2
-clarabel==0.7.1
+clarabel==0.9.0
click==8.1.7
contourpy==1.2.1
-coverage==7.5.1
+coverage==7.5.3
cvxopt==1.3.2
cvxpy==1.5.1
cycler==0.12.1
@@ -19,7 +20,7 @@ exceptiongroup==1.2.1
execnet==2.1.1
filelock==3.14.0
flaky==3.8.1
-fonttools==4.51.0
+fonttools==4.53.0
future==1.0.0
identify==2.5.36
idna==3.7
@@ -30,14 +31,16 @@ isort==5.13.2
jax==0.4.23
jaxlib==0.4.23
kiwisolver==1.4.5
+lazy-object-proxy==1.10.0
matplotlib==3.9.0
+mccabe==0.6.1
ml-dtypes==0.4.0
mypy-extensions==1.0.0
networkx==3.2.1
-nodeenv==1.8.0
+nodeenv==1.9.0
numpy==1.26.4
opt-einsum==3.3.0
-osqp==0.6.5
+osqp==0.6.7
packaging==24.0
pathspec==0.12.1
PennyLane_Lightning==0.37.0
@@ -47,8 +50,9 @@ pluggy==1.5.0
pre-commit==3.7.1
py==1.11.0
py-cpuinfo==9.0.0
+pylint==2.7.4
pyparsing==3.1.2
-pytest==8.2.0
+pytest==8.2.1
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-forked==1.6.0
@@ -58,15 +62,16 @@ pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
PyYAML==6.0.1
qdldl==0.1.7.post2
-requests==2.31.0
+requests==2.32.3
rustworkx==0.14.2
-scipy==1.11.4
-scs==3.2.4.post1
+scipy==1.12.0
+scs==3.2.4.post2
semantic-version==2.10.0
six==1.16.0
toml==0.10.2
tomli==2.0.1
-typing_extensions==4.11.0
+typing_extensions==4.12.1
urllib3==2.2.1
virtualenv==20.26.2
-zipp==3.18.2
+wrapt==1.12.1
+zipp==3.19.1
diff --git a/.github/stable/tf.txt b/.github/stable/tf.txt
index 41cdf783405..1bf9daf1825 100644
--- a/.github/stable/tf.txt
+++ b/.github/stable/tf.txt
@@ -1,17 +1,18 @@
absl-py==2.1.0
appdirs==1.4.4
+astroid==2.6.6
astunparse==1.6.3
autograd==1.6.2
-autoray==0.6.9
+autoray==0.6.12
black==24.4.2
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
cfgv==3.4.0
charset-normalizer==3.3.2
-clarabel==0.7.1
+clarabel==0.9.0
click==8.1.7
contourpy==1.2.1
-coverage==7.5.1
+coverage==7.5.3
cvxopt==1.3.2
cvxpy==1.5.1
cycler==0.12.1
@@ -22,11 +23,11 @@ execnet==2.1.1
filelock==3.14.0
flaky==3.8.1
flatbuffers==24.3.25
-fonttools==4.51.0
+fonttools==4.53.0
future==1.0.0
gast==0.5.4
google-pasta==0.2.0
-grpcio==1.63.0
+grpcio==1.64.0
h5py==3.11.0
identify==2.5.36
idna==3.7
@@ -36,21 +37,23 @@ iniconfig==2.0.0
isort==5.13.2
keras==3.3.3
kiwisolver==1.4.5
+lazy-object-proxy==1.10.0
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
+mccabe==0.6.1
mdurl==0.1.2
ml-dtypes==0.3.2
mypy-extensions==1.0.0
namex==0.0.8
networkx==3.2.1
-nodeenv==1.8.0
+nodeenv==1.9.0
numpy==1.26.4
opt-einsum==3.3.0
optree==0.11.0
-osqp==0.6.5
+osqp==0.6.7
packaging==24.0
pathspec==0.12.1
PennyLane_Lightning==0.37.0
@@ -62,8 +65,9 @@ protobuf==4.25.3
py==1.11.0
py-cpuinfo==9.0.0
Pygments==2.18.0
+pylint==2.7.4
pyparsing==3.1.2
-pytest==8.2.0
+pytest==8.2.1
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-forked==1.6.0
@@ -73,11 +77,11 @@ pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
PyYAML==6.0.1
qdldl==0.1.7.post2
-requests==2.31.0
+requests==2.32.3
rich==13.7.1
rustworkx==0.14.2
scipy==1.11.4
-scs==3.2.4.post1
+scs==3.2.4.post2
semantic-version==2.10.0
six==1.16.0
tensorboard==2.16.2
@@ -88,9 +92,9 @@ termcolor==2.4.0
tf_keras==2.16.0
toml==0.10.2
tomli==2.0.1
-typing_extensions==4.11.0
+typing_extensions==4.12.1
urllib3==2.2.1
virtualenv==20.26.2
Werkzeug==3.0.3
-wrapt==1.16.0
-zipp==3.18.2
+wrapt==1.12.1
+zipp==3.19.1
diff --git a/.github/stable/torch.txt b/.github/stable/torch.txt
index 1acf7336c96..18a7db0f881 100644
--- a/.github/stable/torch.txt
+++ b/.github/stable/torch.txt
@@ -1,15 +1,16 @@
appdirs==1.4.4
+astroid==2.6.6
autograd==1.6.2
-autoray==0.6.9
+autoray==0.6.12
black==24.4.2
cachetools==5.3.3
-certifi==2024.2.2
+certifi==2024.6.2
cfgv==3.4.0
charset-normalizer==3.3.2
-clarabel==0.7.1
+clarabel==0.9.0
click==8.1.7
contourpy==1.2.1
-coverage==7.5.1
+coverage==7.5.3
cvxopt==1.3.2
cvxpy==1.5.1
cycler==0.12.1
@@ -19,7 +20,7 @@ exceptiongroup==1.2.1
execnet==2.1.1
filelock==3.14.0
flaky==3.8.1
-fonttools==4.51.0
+fonttools==4.53.0
fsspec==2024.5.0
future==1.0.0
identify==2.5.36
@@ -29,14 +30,16 @@ iniconfig==2.0.0
isort==5.13.2
Jinja2==3.1.4
kiwisolver==1.4.5
+lazy-object-proxy==1.10.0
MarkupSafe==2.1.5
matplotlib==3.9.0
+mccabe==0.6.1
mpmath==1.3.0
mypy-extensions==1.0.0
networkx==3.2.1
-nodeenv==1.8.0
+nodeenv==1.9.0
numpy==1.26.4
-osqp==0.6.5
+osqp==0.6.7
packaging==24.0
pathspec==0.12.1
PennyLane_Lightning==0.37.0
@@ -46,8 +49,9 @@ pluggy==1.5.0
pre-commit==3.7.1
py==1.11.0
py-cpuinfo==9.0.0
+pylint==2.7.4
pyparsing==3.1.2
-pytest==8.2.0
+pytest==8.2.1
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-forked==1.6.0
@@ -57,17 +61,18 @@ python-dateutil==2.9.0.post0
pytorch-triton-rocm==2.3.0
PyYAML==6.0.1
qdldl==0.1.7.post2
-requests==2.31.0
+requests==2.32.3
rustworkx==0.14.2
scipy==1.11.4
-scs==3.2.4.post1
+scs==3.2.4.post2
semantic-version==2.10.0
six==1.16.0
-sympy==1.12
+sympy==1.12.1
toml==0.10.2
tomli==2.0.1
torch==2.3.0+rocm6.0
-typing_extensions==4.11.0
+typing_extensions==4.12.1
urllib3==2.2.1
virtualenv==20.26.2
-zipp==3.18.2
+wrapt==1.12.1
+zipp==3.19.1
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3d90194d325..6d198b3b145 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -108,9 +108,10 @@
`m = measure(0); qml.sample(m)`.
[(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673)
-* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR.
+* PennyLane operators, measurements, and QNodes can now automatically be captured as instructions in JAXPR.
[(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564)
[(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511)
+ [(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708)
[(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523)
* The `decompose` transform has an `error` kwarg to specify the type of error that should be raised,
@@ -151,6 +152,10 @@
* ``qml.QutritDepolarizingChannel`` has been added, allowing for depolarizing noise to be simulated on the `default.qutrit.mixed` device.
[(#5502)](https://github.com/PennyLaneAI/pennylane/pull/5502)
+* `qml.QutritAmplitudeDamping` channel has been added, allowing for noise processes modelled by amplitude damping to be simulated on the `default.qutrit.mixed` device.
+ [(#5503)](https://github.com/PennyLaneAI/pennylane/pull/5503)
+ [(#5757)](https://github.com/PennyLaneAI/pennylane/pull/5757)
+
Breaking changes 💔
* A custom decomposition can no longer be provided to `QDrift`. Instead, apply the operations in your custom
@@ -176,9 +181,6 @@
* `Controlled.wires` does not include `self.work_wires` anymore. That can be accessed separately through `Controlled.work_wires`.
Consequently, `Controlled.active_wires` has been removed in favour of the more common `Controlled.wires`.
[(#5728)](https://github.com/PennyLaneAI/pennylane/pull/5728)
-
-* `qml.QutritAmplitudeDamping` channel has been added, allowing for noise processes modelled by amplitude damping to be simulated on the `default.qutrit.mixed` device.
- [(#5503)](https://github.com/PennyLaneAI/pennylane/pull/5503)
Deprecations 👋
diff --git a/pennylane/__init__.py b/pennylane/__init__.py
index 80419832ae0..139ebc29222 100644
--- a/pennylane/__init__.py
+++ b/pennylane/__init__.py
@@ -218,6 +218,9 @@ def device(name, *args, **kwargs):
* :mod:`'default.qutrit' `: a simple
state simulator of qutrit-based quantum circuit architectures.
+ * :mod:`'default.qutrit.mixed' `: a
+ mixed-state simulator of qutrit-based quantum circuit architectures.
+
* :mod:`'default.gaussian' `: a simple simulator
of Gaussian states and operations on continuous-variable circuit architectures.
diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py
index 33edf07dbe6..3cbb39a3dfa 100644
--- a/pennylane/capture/__init__.py
+++ b/pennylane/capture/__init__.py
@@ -33,6 +33,7 @@
~create_measurement_obs_primitive
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
+ ~qnode_call
To activate and deactivate the new PennyLane program capturing mechanism, use
the switches ``qml.capture.enable`` and ``qml.capture.disable``.
@@ -132,6 +133,7 @@ def _(*args, **kwargs):
create_measurement_wires_primitive,
create_measurement_mcm_primitive,
)
+from .capture_qnode import qnode_call
def __getattr__(key):
diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py
new file mode 100644
index 00000000000..bbcd731e934
--- /dev/null
+++ b/pennylane/capture/capture_qnode.py
@@ -0,0 +1,167 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This submodule defines a capture compatible call to QNodes.
+"""
+
+from functools import lru_cache, partial
+
+import pennylane as qml
+
+has_jax = True
+try:
+ import jax
+except ImportError:
+ has_jax = False
+
+
+def _get_shapes_for(*measurements, shots=None, num_device_wires=0):
+ if jax.config.jax_enable_x64:
+ dtype_map = {
+ float: jax.numpy.float64,
+ int: jax.numpy.int64,
+ complex: jax.numpy.complex128,
+ }
+ else:
+ dtype_map = {
+ float: jax.numpy.float32,
+ int: jax.numpy.int32,
+ complex: jax.numpy.complex64,
+ }
+
+ shapes = []
+ if not shots:
+ shots = [None]
+
+ for s in shots:
+ for m in measurements:
+ shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires)
+ shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype)))
+ return shapes
+
+
+@lru_cache()
+def _get_qnode_prim():
+ if not has_jax:
+ return None
+ qnode_prim = jax.core.Primitive("qnode")
+ qnode_prim.multiple_results = True
+
+ @qnode_prim.def_impl
+ def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
+ def qfunc(*inner_args):
+ return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args)
+
+ qnode = qml.QNode(qfunc, device, **qnode_kwargs)
+ return qnode._impl_call(*args, shots=shots) # pylint: disable=protected-access
+
+ # pylint: disable=unused-argument
+ @qnode_prim.def_abstract_eval
+ def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr):
+ mps = qfunc_jaxpr.out_avals
+ return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))
+
+ return qnode_prim
+
+
+# pylint: disable=protected-access
+def _get_device_shots(device) -> "qml.measurements.Shots":
+ if isinstance(device, qml.devices.LegacyDevice):
+ if device._shot_vector:
+ return qml.measurements.Shots(device._raw_shot_sequence)
+ return qml.measurements.Shots(device.shots)
+ return device.shots
+
+
+def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result":
+ """A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``.
+
+ Args:
+ qnode (QNode): a QNode
+ args: the arguments the QNode is called with
+
+ Keyword Args:
+ Any keyword arguments accepted by the quantum function
+
+ Returns:
+ qml.typing.Result: the result of a qnode execution
+
+ **Example:**
+
+ .. code-block:: python
+
+ @qml.qnode(qml.device('lightning.qubit', wires=1))
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return qml.expval(qml.Z(0)), qml.probs()
+
+ def f(x):
+ expval_z, probs = circuit(np.pi * x, shots=50)
+ return 2*expval_z + probs
+
+ jaxpr = jax.make_jaxpr(f)(0.1)
+ print("jaxpr:")
+ print(jaxpr)
+
+ res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.7)
+ print()
+ print("result:")
+ print(res)
+
+
+ .. code-block:: none
+
+ jaxpr:
+ { lambda ; a:f32[]. let
+ b:f32[] = mul 3.141592653589793 a
+ c:f32[] d:f32[2] = qnode[
+ device=
+ qfunc_jaxpr={ lambda ; e:f32[]. let
+ _:AbstractOperator() = RX[n_wires=1] e 0
+ f:AbstractOperator() = PauliZ[n_wires=1] 0
+ g:AbstractMeasurement(n_wires=None) = expval_obs f
+ h:AbstractMeasurement(n_wires=0) = probs_wires
+ in (g, h) }
+ qnode_kwargs={'diff_method': 'best', 'grad_on_execution': 'best', 'cache': False, 'cachesize': 10000, 'max_diff': 1, 'max_expansion': 10, 'device_vjp': False}
+ shots=Shots(total=50)
+ ] b
+ i:f32[] = mul 2.0 c
+ j:f32[2] = add i d
+ in (j,) }
+
+ result:
+ [Array([-1.3 , -0.74], dtype=float32)]
+
+
+ """
+ if "shots" in kwargs:
+ shots = qml.measurements.Shots(kwargs.pop("shots"))
+ else:
+ shots = _get_device_shots(qnode.device)
+ if shots.has_partitioned_shots:
+ # Questions over the pytrees and the nested result object shape
+ raise NotImplementedError("shot vectors are not yet supported with plxpr capture.")
+
+ if not qnode.device.wires:
+ raise NotImplementedError("devices must specify wires for integration with plxpr capture.")
+
+ qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func
+
+ qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args)
+ qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs}
+ qnode_prim = _get_qnode_prim()
+
+ return qnode_prim.bind(
+ *args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr
+ )
diff --git a/pennylane/ops/qutrit/channel.py b/pennylane/ops/qutrit/channel.py
index c50572cf206..82f5d9c4407 100644
--- a/pennylane/ops/qutrit/channel.py
+++ b/pennylane/ops/qutrit/channel.py
@@ -235,8 +235,10 @@ class QutritAmplitudeDamping(Channel):
K_0 = \begin{bmatrix}
1 & 0 & 0\\
0 & \sqrt{1-\gamma_1} & 0 \\
- 0 & 0 & \sqrt{1-\gamma_2}
- \end{bmatrix}, \quad
+ 0 & 0 & \sqrt{1-(\gamma_2+\gamma_3)}
+ \end{bmatrix}
+
+ .. math::
K_1 = \begin{bmatrix}
0 & \sqrt{\gamma_1} & 0 \\
0 & 0 & 0 \\
@@ -246,70 +248,95 @@ class QutritAmplitudeDamping(Channel):
0 & 0 & \sqrt{\gamma_2} \\
0 & 0 & 0 \\
0 & 0 & 0
+ \end{bmatrix}, \quad
+ K_3 = \begin{bmatrix}
+ 0 & 0 & 0 \\
+ 0 & 0 & \sqrt{\gamma_3} \\
+ 0 & 0 & 0
\end{bmatrix}
- where :math:`\gamma_1 \in [0, 1]` and :math:`\gamma_2 \in [0, 1]` are the amplitude damping
- probabilities for subspaces (0,1) and (0,2) respectively.
+ where :math:`\gamma_1, \gamma_2, \gamma_3 \in [0, 1]` are the amplitude damping
+ probabilities for subspaces (0,1), (0,2), and (1,2) respectively.
.. note::
- The Kraus operators :math:`\{K_0, K_1, K_2\}` are adapted from [`1 `_] (Eq. 8).
+ When :math:`\gamma_3=0` then Kraus operators :math:`\{K_0, K_1, K_2\}` are adapted from
+ [`1 `_] (Eq. 8).
+
+ The Kraus operator :math:`K_3` represents the :math:`|2 \rangle \rightarrow |1 \rangle` transition which is more
+ likely on some devices [`2 `_] (Sec II.A).
+
+ To maintain normalization :math:`\gamma_2 + \gamma_3 \leq 1`.
+
**Details:**
* Number of wires: 1
- * Number of parameters: 2
+ * Number of parameters: 3
Args:
gamma_1 (float): :math:`|1 \rangle \rightarrow |0 \rangle` amplitude damping probability.
gamma_2 (float): :math:`|2 \rangle \rightarrow |0 \rangle` amplitude damping probability.
+ gamma_3 (float): :math:`|2 \rangle \rightarrow |1 \rangle` amplitude damping probability.
wires (Sequence[int] or int): the wire the channel acts on
id (str or None): String representing the operation (optional)
"""
- num_params = 2
+ num_params = 3
num_wires = 1
grad_method = "F"
- def __init__(self, gamma_1, gamma_2, wires, id=None):
- # Verify gamma_1 and gamma_2
- for gamma in (gamma_1, gamma_2):
- if not (math.is_abstract(gamma_1) or math.is_abstract(gamma_2)):
+ def __init__(self, gamma_1, gamma_2, gamma_3, wires, id=None):
+ # Verify input
+ for gamma in (gamma_1, gamma_2, gamma_3):
+ if not math.is_abstract(gamma):
if not 0.0 <= gamma <= 1.0:
raise ValueError("Each probability must be in the interval [0,1]")
- super().__init__(gamma_1, gamma_2, wires=wires, id=id)
+ if not (math.is_abstract(gamma_2) or math.is_abstract(gamma_3)):
+ if not 0.0 <= gamma_2 + gamma_3 <= 1.0:
+ raise ValueError(r"\gamma_2+\gamma_3 must be in the interval [0,1]")
+ super().__init__(gamma_1, gamma_2, gamma_3, wires=wires, id=id)
@staticmethod
- def compute_kraus_matrices(gamma_1, gamma_2): # pylint:disable=arguments-differ
+ def compute_kraus_matrices(gamma_1, gamma_2, gamma_3): # pylint:disable=arguments-differ
r"""Kraus matrices representing the ``QutritAmplitudeDamping`` channel.
Args:
gamma_1 (float): :math:`|1\rangle \rightarrow |0\rangle` amplitude damping probability.
gamma_2 (float): :math:`|2\rangle \rightarrow |0\rangle` amplitude damping probability.
+ gamma_3 (float): :math:`|2\rangle \rightarrow |1\rangle` amplitude damping probability.
Returns:
list(array): list of Kraus matrices
**Example**
- >>> qml.QutritAmplitudeDamping.compute_kraus_matrices(0.5, 0.25)
+ >>> qml.QutritAmplitudeDamping.compute_kraus_matrices(0.5, 0.25, 0.36)
[
array([ [1. , 0. , 0. ],
[0. , 0.70710678, 0. ],
- [0. , 0. , 0.8660254 ]]),
+ [0. , 0. , 0.6244998 ]]),
array([ [0. , 0.70710678, 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ]]),
array([ [0. , 0. , 0.5 ],
[0. , 0. , 0. ],
[0. , 0. , 0. ]])
+ array([ [0. , 0. , 0. ],
+ [0. , 0. , 0.6 ],
+ [0. , 0. , 0. ]])
]
"""
- K0 = math.diag([1, math.sqrt(1 - gamma_1 + math.eps), math.sqrt(1 - gamma_2 + math.eps)])
+ K0 = math.diag(
+ [1, math.sqrt(1 - gamma_1 + math.eps), math.sqrt(1 - gamma_2 - gamma_3 + math.eps)]
+ )
K1 = math.sqrt(gamma_1 + math.eps) * math.convert_like(
math.cast_like(math.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]), gamma_1), gamma_1
)
K2 = math.sqrt(gamma_2 + math.eps) * math.convert_like(
math.cast_like(math.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]), gamma_2), gamma_2
)
- return [K0, K1, K2]
+ K3 = math.sqrt(gamma_3 + math.eps) * math.convert_like(
+ math.cast_like(math.array([[0, 0, 0], [0, 0, 1], [0, 0, 0]]), gamma_3), gamma_3
+ )
+ return [K0, K1, K2, K3]
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 57d6e89c28b..947a636b08b 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1081,7 +1081,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml
res, self._qfunc_output, self._tape.shots.has_partitioned_shots
)
- def __call__(self, *args, **kwargs) -> qml.typing.Result:
+ def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
old_interface = self.interface
if old_interface == "auto":
@@ -1113,6 +1113,11 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
return res
+ def __call__(self, *args, **kwargs) -> qml.typing.Result:
+ if qml.capture.enabled():
+ return qml.capture.qnode_call(self, *args, **kwargs)
+ return self._impl_call(*args, **kwargs)
+
qnode = lambda device, **kwargs: functools.partial(QNode, device=device, **kwargs)
qnode.__doc__ = QNode.__doc__
diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py
new file mode 100644
index 00000000000..1fe57c65e23
--- /dev/null
+++ b/tests/capture/test_capture_qnode.py
@@ -0,0 +1,298 @@
+# Copyright 2018-2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for capturing a qnode into jaxpr.
+"""
+from functools import partial
+
+# pylint: disable=protected-access
+import pytest
+
+import pennylane as qml
+from pennylane.capture.capture_qnode import _get_qnode_prim
+
+qnode_prim = _get_qnode_prim()
+
+pytestmark = pytest.mark.jax
+
+jax = pytest.importorskip("jax")
+
+
+@pytest.fixture(autouse=True)
+def enable_disable_plxpr():
+ qml.capture.enable()
+ yield
+ qml.capture.disable()
+
+
+@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy"))
+def test_error_if_shot_vector(dev_name):
+ """Test that a NotImplementedError is raised if a shot vector is provided."""
+
+ dev = qml.device(dev_name, wires=1, shots=(50, 50))
+
+ @qml.qnode(dev)
+ def circuit():
+ return qml.sample()
+
+ with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"):
+ jax.make_jaxpr(circuit)()
+
+ with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"):
+ circuit()
+
+ jax.make_jaxpr(partial(circuit, shots=50))() # should run fine
+ res = circuit(shots=50)
+ assert qml.math.allclose(res, jax.numpy.zeros((50,)))
+
+
+@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy"))
+def test_error_if_overridden_shot_vector(dev_name):
+ """Test that a NotImplementedError is raised if a shot vector is provided on call."""
+
+ dev = qml.device(dev_name, wires=1)
+
+ @qml.qnode(dev)
+ def circuit():
+ return qml.sample()
+
+ with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"):
+ jax.make_jaxpr(partial(circuit, shots=(1, 1, 1)))()
+
+
+def test_error_if_no_device_wires():
+ """Test that a NotImplementedError is raised if the device does not provide wires."""
+
+ dev = qml.device("default.qubit")
+
+ @qml.qnode(dev)
+ def circuit():
+ return qml.sample()
+
+ with pytest.raises(NotImplementedError, match="devices must specify wires"):
+ jax.make_jaxpr(circuit)()
+
+ with pytest.raises(NotImplementedError, match="devices must specify wires"):
+ circuit()
+
+
+@pytest.mark.parametrize("x64_mode", (True, False))
+def test_simple_qnode(x64_mode):
+ """Test capturing a qnode for a simple use."""
+
+ initial_mode = jax.config.jax_enable_x64
+ jax.config.update("jax_enable_x64", x64_mode)
+
+ dev = qml.device("default.qubit", wires=4)
+
+ @qml.qnode(dev)
+ def circuit(x):
+ qml.RX(x, wires=0)
+ return qml.expval(qml.Z(0))
+
+ res = circuit(0.5)
+ assert qml.math.allclose(res, jax.numpy.cos(0.5))
+
+ jaxpr = jax.make_jaxpr(circuit)(0.5)
+
+ assert len(jaxpr.eqns) == 1
+ eqn0 = jaxpr.eqns[0]
+
+ fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32
+
+ assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)]
+
+ assert eqn0.primitive == qnode_prim
+ assert eqn0.invars[0].aval == jaxpr.in_avals[0]
+ assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype)
+
+ assert eqn0.params["device"] == dev
+ assert eqn0.params["shots"] == qml.measurements.Shots(None)
+ expected_kwargs = {"diff_method": "best"}
+ expected_kwargs.update(circuit.execute_kwargs)
+ assert eqn0.params["qnode_kwargs"] == expected_kwargs
+
+ qfunc_jaxpr = eqn0.params["qfunc_jaxpr"]
+ assert len(qfunc_jaxpr.eqns) == 3
+ assert qfunc_jaxpr.eqns[0].primitive == qml.RX._primitive
+ assert qfunc_jaxpr.eqns[1].primitive == qml.Z._primitive
+ assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive
+
+ assert len(eqn0.outvars) == 1
+ assert eqn0.outvars[0].aval == jax.core.ShapedArray((), fdtype)
+
+ output = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5)
+ assert qml.math.allclose(output[0], jax.numpy.cos(0.5))
+
+ jax.config.update("jax_enable_x64", initial_mode)
+
+
+@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy"))
+@pytest.mark.parametrize("x64_mode", (True, False))
+def test_overriding_shots(dev_name, x64_mode):
+ """Test that the number of shots can be overridden on call."""
+ initial_mode = jax.config.jax_enable_x64
+ jax.config.update("jax_enable_x64", x64_mode)
+
+ dev = qml.device(dev_name, wires=1)
+
+ @qml.qnode(dev)
+ def circuit():
+ return qml.sample()
+
+ jaxpr = jax.make_jaxpr(partial(circuit, shots=50))()
+ assert len(jaxpr.eqns) == 1
+ eqn0 = jaxpr.eqns[0]
+
+ assert eqn0.primitive == qnode_prim
+ assert eqn0.params["device"] == dev
+ assert eqn0.params["shots"] == qml.measurements.Shots(50)
+ assert (
+ eqn0.params["qfunc_jaxpr"].eqns[0].primitive == qml.measurements.SampleMP._wires_primitive
+ )
+
+ assert eqn0.outvars[0].aval == jax.core.ShapedArray(
+ (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32
+ )
+
+ res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
+ assert qml.math.allclose(res, jax.numpy.zeros((50,)))
+
+ jax.config.update("jax_enable_x64", initial_mode)
+
+
+def test_providing_keyword_argument():
+ """Test that keyword arguments can be provided to the qnode."""
+
+ @qml.qnode(qml.device("default.qubit", wires=1))
+ def circuit(*, n_iterations=0):
+ for _ in range(n_iterations):
+ qml.X(0)
+ return qml.probs()
+
+ jaxpr = jax.make_jaxpr(partial(circuit, n_iterations=3))()
+
+ assert jaxpr.eqns[0].primitive == qnode_prim
+
+ qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
+ for i in range(3):
+ assert qfunc_jaxpr.eqns[i].primitive == qml.PauliX._primitive
+ assert len(qfunc_jaxpr.eqns) == 4
+
+ res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
+ assert qml.math.allclose(res, jax.numpy.array([0, 1]))
+
+ res2 = circuit(n_iterations=4)
+ assert qml.math.allclose(res2, jax.numpy.array([1, 0]))
+
+
+@pytest.mark.parametrize("x64_mode", (True, False))
+def test_multiple_measurements(x64_mode):
+ """Test that the qnode can return multiple measurements."""
+ initial_mode = jax.config.jax_enable_x64
+ jax.config.update("jax_enable_x64", x64_mode)
+
+ @qml.qnode(qml.device("default.qubit", wires=3, shots=50))
+ def circuit():
+ return qml.sample(), qml.probs(wires=(0, 1)), qml.expval(qml.Z(0))
+
+ jaxpr = jax.make_jaxpr(circuit)()
+
+ qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
+
+ assert qfunc_jaxpr.eqns[0].primitive == qml.measurements.SampleMP._wires_primitive
+ assert qfunc_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive
+ assert qfunc_jaxpr.eqns[2].primitive == qml.Z._primitive
+ assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive
+
+ assert jaxpr.out_avals[0] == jax.core.ShapedArray(
+ (50, 3), jax.numpy.int64 if x64_mode else jax.numpy.int32
+ )
+ assert jaxpr.out_avals[1] == jax.core.ShapedArray(
+ (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32
+ )
+ assert jaxpr.out_avals[2] == jax.core.ShapedArray(
+ (), jax.numpy.float64 if x64_mode else jax.numpy.float32
+ )
+
+ res1, res2, res3 = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
+ assert qml.math.allclose(res1, jax.numpy.zeros((50, 3)))
+ assert qml.math.allclose(res2, jax.numpy.array([1, 0, 0, 0]))
+ assert qml.math.allclose(res3, 1.0)
+
+ res1, res2, res3 = circuit()
+ assert qml.math.allclose(res1, jax.numpy.zeros((50, 3)))
+ assert qml.math.allclose(res2, jax.numpy.array([1, 0, 0, 0]))
+ assert qml.math.allclose(res3, 1.0)
+
+ jax.config.update("jax_enable_x64", initial_mode)
+
+
+@pytest.mark.parametrize("x64_mode", (True, False))
+def test_complex_return_types(x64_mode):
+ """Test returning measurements with complex values."""
+
+ initial_mode = jax.config.jax_enable_x64
+ jax.config.update("jax_enable_x64", x64_mode)
+
+ @qml.qnode(qml.device("default.qubit", wires=3))
+ def circuit():
+ return qml.state(), qml.density_matrix(wires=(0, 1))
+
+ jaxpr = jax.make_jaxpr(circuit)()
+
+ qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
+
+ assert qfunc_jaxpr.eqns[0].primitive == qml.measurements.StateMP._wires_primitive
+ assert qfunc_jaxpr.eqns[1].primitive == qml.measurements.DensityMatrixMP._wires_primitive
+
+ assert jaxpr.out_avals[0] == jax.core.ShapedArray(
+ (8,), jax.numpy.complex128 if x64_mode else jax.numpy.complex64
+ )
+ assert jaxpr.out_avals[1] == jax.core.ShapedArray(
+ (4, 4), jax.numpy.complex128 if x64_mode else jax.numpy.complex64
+ )
+
+ jax.config.update("jax_enable_x64", initial_mode)
+
+
+def test_capture_qnode_kwargs():
+ """Test that qnode kwargs are captured as parameters."""
+
+ dev = qml.device("default.qubit", wires=3)
+
+ @qml.qnode(
+ dev,
+ diff_method="parameter-shift",
+ grad_on_execution=False,
+ cache=True,
+ cachesize=10,
+ max_diff=2,
+ )
+ def circuit():
+ return qml.expval(qml.Z(0))
+
+ jaxpr = jax.make_jaxpr(circuit)()
+
+ assert jaxpr.eqns[0].primitive == qnode_prim
+ expected = {
+ "diff_method": "parameter-shift",
+ "grad_on_execution": False,
+ "cache": True,
+ "cachesize": 10,
+ "max_diff": 2,
+ "max_expansion": 10,
+ "device_vjp": False,
+ }
+ assert jaxpr.eqns[0].params["qnode_kwargs"] == expected
diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py
index 9574fdbec25..0d80adf2943 100644
--- a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py
+++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py
@@ -126,6 +126,7 @@ def test_measurement_is_swapped_out(self, mp_fn, mp_cls, shots):
(qml.Snapshot(), True),
(qml.TRX(1.1, 0), True),
(qml.QutritDepolarizingChannel(0.4, 0), True),
+ (qml.QutritAmplitudeDamping(0.1, 0.2, 0.12, 0), True),
],
)
def test_accepted_operator(self, op, expected):
diff --git a/tests/ops/qutrit/test_qutrit_channel_ops.py b/tests/ops/qutrit/test_qutrit_channel_ops.py
index 585f623f499..4b2612815ce 100644
--- a/tests/ops/qutrit/test_qutrit_channel_ops.py
+++ b/tests/ops/qutrit/test_qutrit_channel_ops.py
@@ -176,15 +176,15 @@ class TestQutritAmplitudeDamping:
def test_gamma_zero(self, tol):
"""Test gamma_1=gamma_2=0 gives correct Kraus matrices"""
- kraus_mats = qml.QutritAmplitudeDamping(0, 0, wires=0).kraus_matrices()
+ kraus_mats = qml.QutritAmplitudeDamping(0, 0, 0, wires=0).kraus_matrices()
assert np.allclose(kraus_mats[0], np.eye(3), atol=tol, rtol=0)
- assert np.allclose(kraus_mats[1], np.zeros((3, 3)), atol=tol, rtol=0)
- assert np.allclose(kraus_mats[2], np.zeros((3, 3)), atol=tol, rtol=0)
+ for kraus_mat in kraus_mats[1:]:
+ assert np.allclose(kraus_mat, np.zeros((3, 3)), atol=tol, rtol=0)
- @pytest.mark.parametrize("gamma1,gamma2", ((0.1, 0.2), (0.75, 0.75)))
- def test_gamma_arbitrary(self, gamma1, gamma2, tol):
- """Test the correct Kraus matrices are returned, also ensures that the sum of gammas can be over 1."""
- K_0 = np.diag((1, np.sqrt(1 - gamma1), np.sqrt(1 - gamma2)))
+ @pytest.mark.parametrize("gamma1,gamma2,gamma3", ((0.1, 0.2, 0.3), (0.75, 0.75, 0.25)))
+ def test_gamma_arbitrary(self, gamma1, gamma2, gamma3, tol):
+ """Test the correct Kraus matrices are returned."""
+ K_0 = np.diag((1, np.sqrt(1 - gamma1), np.sqrt(1 - gamma2 - gamma3)))
K_1 = np.zeros((3, 3))
K_1[0, 1] = np.sqrt(gamma1)
@@ -192,33 +192,48 @@ def test_gamma_arbitrary(self, gamma1, gamma2, tol):
K_2 = np.zeros((3, 3))
K_2[0, 2] = np.sqrt(gamma2)
- expected = [K_0, K_1, K_2]
- damping_channel = qml.QutritAmplitudeDamping(gamma1, gamma2, wires=0)
+ K_3 = np.zeros((3, 3))
+ K_3[1, 2] = np.sqrt(gamma3)
+
+ expected = [K_0, K_1, K_2, K_3]
+ damping_channel = qml.QutritAmplitudeDamping(gamma1, gamma2, gamma3, wires=0)
assert np.allclose(damping_channel.kraus_matrices(), expected, atol=tol, rtol=0)
- @pytest.mark.parametrize("gamma1,gamma2", ((1.5, 0.0), (0.0, 1.0 + math.eps)))
- def test_gamma_invalid_parameter(self, gamma1, gamma2):
- """Ensures that error is thrown when gamma_1 or gamma_2 are outside [0,1]"""
- with pytest.raises(ValueError, match="Each probability must be in the interval"):
- channel.QutritAmplitudeDamping(gamma1, gamma2, wires=0).kraus_matrices()
+ @pytest.mark.parametrize(
+ "gamma1,gamma2,gamma3",
+ (
+ (1.5, 0.0, 0.0),
+ (0.0, 1.0 + math.eps, 0.0),
+ (0.0, 0.0, 1.1),
+ (0.0, 0.33, 0.67 + math.eps),
+ ),
+ )
+ def test_gamma_invalid_parameter(self, gamma1, gamma2, gamma3):
+ """Ensures that error is thrown when gamma_1, gamma_2, gamma_3, or (gamma_2 + gamma_3) are outside [0,1]"""
+ with pytest.raises(ValueError, match="must be in the interval"):
+ channel.QutritAmplitudeDamping(gamma1, gamma2, gamma3, wires=0).kraus_matrices()
@staticmethod
- def expected_jac_fn(gamma_1, gamma_2):
+ def expected_jac_fn(gamma_1, gamma_2, gamma_3):
"""Gets the expected Jacobian of Kraus matrices"""
- partial_1 = [math.zeros((3, 3)) for _ in range(3)]
+ partial_1 = [math.zeros((3, 3)) for _ in range(4)]
partial_1[0][1, 1] = -1 / (2 * math.sqrt(1 - gamma_1))
partial_1[1][0, 1] = 1 / (2 * math.sqrt(gamma_1))
- partial_2 = [math.zeros((3, 3)) for _ in range(3)]
- partial_2[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2))
+ partial_2 = [math.zeros((3, 3)) for _ in range(4)]
+ partial_2[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2 - gamma_3))
partial_2[2][0, 2] = 1 / (2 * math.sqrt(gamma_2))
- return [partial_1, partial_2]
+ partial_3 = [math.zeros((3, 3)) for _ in range(4)]
+ partial_3[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2 - gamma_3))
+ partial_3[3][1, 2] = 1 / (2 * math.sqrt(gamma_3))
+
+ return [partial_1, partial_2, partial_3]
@staticmethod
- def kraus_fn(gamma_1, gamma_2):
+ def kraus_fn(gamma_1, gamma_2, gamma_3):
"""Gets the Kraus matrices of QutritAmplitudeDamping channel, used for differentiation."""
- damping_channel = qml.QutritAmplitudeDamping(gamma_1, gamma_2, wires=0)
+ damping_channel = qml.QutritAmplitudeDamping(gamma_1, gamma_2, gamma_3, wires=0)
return math.stack(damping_channel.kraus_matrices())
@pytest.mark.autograd
@@ -226,8 +241,10 @@ def test_kraus_jac_autograd(self):
"""Tests Jacobian of Kraus matrices using autograd."""
gamma_1 = pnp.array(0.43, requires_grad=True)
gamma_2 = pnp.array(0.12, requires_grad=True)
- jac = qml.jacobian(self.kraus_fn)(gamma_1, gamma_2)
- assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2))
+ gamma_3 = pnp.array(0.35, requires_grad=True)
+
+ jac = qml.jacobian(self.kraus_fn)(gamma_1, gamma_2, gamma_3)
+ assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3))
@pytest.mark.torch
def test_kraus_jac_torch(self):
@@ -236,11 +253,15 @@ def test_kraus_jac_torch(self):
gamma_1 = torch.tensor(0.43, requires_grad=True)
gamma_2 = torch.tensor(0.12, requires_grad=True)
+ gamma_3 = torch.tensor(0.35, requires_grad=True)
- jac = torch.autograd.functional.jacobian(self.kraus_fn, (gamma_1, gamma_2))
- expected = self.expected_jac_fn(gamma_1.detach().numpy(), gamma_2.detach().numpy())
- assert math.allclose(jac[0].detach().numpy(), expected[0])
- assert math.allclose(jac[1].detach().numpy(), expected[1])
+ jac = torch.autograd.functional.jacobian(self.kraus_fn, (gamma_1, gamma_2, gamma_3))
+ expected = self.expected_jac_fn(
+ gamma_1.detach().numpy(), gamma_2.detach().numpy(), gamma_3.detach().numpy()
+ )
+
+ for res_partial, exp_partial in zip(jac, expected):
+ assert math.allclose(res_partial.detach().numpy(), exp_partial)
@pytest.mark.tf
def test_kraus_jac_tf(self):
@@ -249,10 +270,12 @@ def test_kraus_jac_tf(self):
gamma_1 = tf.Variable(0.43)
gamma_2 = tf.Variable(0.12)
+ gamma_3 = tf.Variable(0.35)
+
with tf.GradientTape() as tape:
- out = self.kraus_fn(gamma_1, gamma_2)
- jac = tape.jacobian(out, (gamma_1, gamma_2))
- assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2))
+ out = self.kraus_fn(gamma_1, gamma_2, gamma_3)
+ jac = tape.jacobian(out, (gamma_1, gamma_2, gamma_3))
+ assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3))
@pytest.mark.jax
def test_kraus_jac_jax(self):
@@ -261,5 +284,7 @@ def test_kraus_jac_jax(self):
gamma_1 = jax.numpy.array(0.43)
gamma_2 = jax.numpy.array(0.12)
- jac = jax.jacobian(self.kraus_fn, argnums=[0, 1])(gamma_1, gamma_2)
- assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2))
+ gamma_3 = jax.numpy.array(0.35)
+
+ jac = jax.jacobian(self.kraus_fn, argnums=[0, 1, 2])(gamma_1, gamma_2, gamma_3)
+ assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3))