diff --git a/.github/stable/all_interfaces.txt b/.github/stable/all_interfaces.txt index 77b45a0d419..05f6fb152b4 100644 --- a/.github/stable/all_interfaces.txt +++ b/.github/stable/all_interfaces.txt @@ -1,17 +1,18 @@ absl-py==2.1.0 appdirs==1.4.4 +astroid==2.6.6 astunparse==1.6.3 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -22,12 +23,12 @@ execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 fsspec==2024.5.0 future==1.0.0 gast==0.5.4 google-pasta==0.2.0 -grpcio==1.63.0 +grpcio==1.64.0 h5py==3.11.0 identify==2.5.36 idna==3.7 @@ -40,22 +41,24 @@ jaxlib==0.4.23 Jinja2==3.1.4 keras==3.3.3 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 libclang==18.1.1 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 +mccabe==0.6.1 mdurl==0.1.2 ml-dtypes==0.3.2 mpmath==1.3.0 mypy-extensions==1.0.0 namex==0.0.8 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 opt-einsum==3.3.0 optree==0.11.0 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -67,8 +70,9 @@ protobuf==4.25.3 py==1.11.0 py-cpuinfo==9.0.0 Pygments==2.18.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -78,14 +82,14 @@ python-dateutil==2.9.0.post0 pytorch-triton-rocm==2.3.0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rich==13.7.1 rustworkx==0.14.2 -scipy==1.11.4 -scs==3.2.4.post1 +scipy==1.12.0 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 -sympy==1.12 +sympy==1.12.1 tensorboard==2.16.2 tensorboard-data-server==0.7.2 tensorflow==2.16.1 @@ -95,9 +99,9 @@ tf_keras==2.16.0 toml==0.10.2 tomli==2.0.1 torch==2.3.0+rocm6.0 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 Werkzeug==3.0.3 -wrapt==1.16.0 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/core.txt b/.github/stable/core.txt index def68dbf5b7..2c9830a6d42 100644 --- a/.github/stable/core.txt +++ b/.github/stable/core.txt @@ -1,15 +1,16 @@ appdirs==1.4.4 +astroid==2.6.6 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -19,7 +20,7 @@ exceptiongroup==1.2.1 execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 -fonttools==4.51.0 +fonttools==4.53.0 future==1.0.0 identify==2.5.36 idna==3.7 @@ -27,12 +28,14 @@ importlib_resources==6.4.0 iniconfig==2.0.0 isort==5.13.2 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 matplotlib==3.9.0 +mccabe==0.6.1 mypy-extensions==1.0.0 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -42,8 +45,9 @@ pluggy==1.5.0 pre-commit==3.7.1 py==1.11.0 py-cpuinfo==9.0.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -53,15 +57,16 @@ pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rustworkx==0.14.2 scipy==1.11.4 -scs==3.2.4.post1 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/doc.txt b/.github/stable/doc.txt index a7492e17485..e4c039b43e6 100644 --- a/.github/stable/doc.txt +++ b/.github/stable/doc.txt @@ -10,7 +10,7 @@ autograd==1.6.2 autoray==0.6.12 Babel==2.15.0 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 charset-normalizer==3.3.2 cirq-core==1.3.0 contourpy==1.2.1 @@ -20,7 +20,7 @@ docutils==0.16 duet==0.2.9 exceptiongroup==1.2.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 frozenlist==1.4.1 fsspec==2024.5.0 future==1.0.0 @@ -29,7 +29,7 @@ google-auth==2.29.0 google-auth-oauthlib==0.4.6 google-pasta==0.2.0 graphviz==0.20.3 -grpcio==1.63.0 +grpcio==1.64.0 h5py==3.11.0 idna==3.7 imagesize==1.4.1 @@ -72,8 +72,8 @@ pybtex-docutils==1.0.3 Pygments==2.18.0 pygments-github-lexers==0.0.5 pyparsing==3.1.2 -pyscf==2.5.0 -pytest==8.2.0 +pyscf==2.6.0 +pytest==8.2.1 python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 @@ -81,7 +81,7 @@ requests==2.28.2 requests-oauthlib==2.0.0 rsa==4.9 rustworkx==0.12.1 -scipy==1.13.0 +scipy==1.13.1 semantic-version==2.10.0 six==1.16.0 snowballstemmer==2.2.0 @@ -98,7 +98,7 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sphinxext-opengraph==0.6.3 -sympy==1.12 +sympy==1.12.1 tensorboard==2.11.2 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 @@ -111,11 +111,11 @@ toml==0.10.2 tomli==2.0.1 torch==1.9.0+cpu tqdm==4.66.4 -typing_extensions==4.11.0 +typing_extensions==4.12.1 tzdata==2024.1 urllib3==1.26.18 Werkzeug==3.0.3 wrapt==1.16.0 xanadu-sphinx-theme==0.5.0 yarl==1.9.4 -zipp==3.18.2 +zipp==3.19.1 diff --git a/.github/stable/external.txt b/.github/stable/external.txt index febc739d8c6..89e5d1cf522 100644 --- a/.github/stable/external.txt +++ b/.github/stable/external.txt @@ -1,32 +1,35 @@ absl-py==2.1.0 -anyio==4.3.0 +anyio==4.4.0 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 +astroid==2.6.6 asttokens==2.4.1 astunparse==1.6.3 async-lru==2.0.4 attrs==23.2.0 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 Babel==2.15.0 beautifulsoup4==4.12.3 black==24.4.2 bleach==6.1.0 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cffi==1.16.0 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 comm==0.2.2 contourpy==1.2.1 -coverage==7.5.1 +cotengra==0.6.2 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 +cytoolz==0.12.3 debugpy==1.8.1 decorator==5.1.1 defusedxml==0.7.1 @@ -40,12 +43,12 @@ fastjsonschema==2.19.1 filelock==3.14.0 flaky==3.8.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 fqdn==1.5.1 future==1.0.0 gast==0.5.4 google-pasta==0.2.0 -grpcio==1.63.0 +grpcio==1.64.0 h11==0.14.0 h5py==3.11.0 httpcore==1.0.5 @@ -71,23 +74,26 @@ jsonschema==4.22.0 jsonschema-specifications==2023.12.1 jupyter-events==0.10.0 jupyter-lsp==2.2.5 -jupyter_client==8.6.1 +jupyter_client==8.6.2 jupyter_core==5.7.2 -jupyter_server==2.14.0 +jupyter_server==2.14.1 jupyter_server_terminals==0.5.3 -jupyterlab==4.1.8 +jupyterlab==4.2.1 jupyterlab-widgets==1.1.7 jupyterlab_pygments==0.3.0 -jupyterlab_server==2.27.1 +jupyterlab_server==2.27.2 keras==3.3.3 kiwisolver==1.4.5 lark==1.1.9 +lazy-object-proxy==1.10.0 libclang==18.1.1 +llvmlite==0.42.0 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 matplotlib-inline==0.1.7 +mccabe==0.6.1 mdurl==0.1.2 mistune==3.0.2 ml-dtypes==0.3.2 @@ -98,13 +104,14 @@ nbconvert==7.16.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.2.1 -nodeenv==1.8.0 -notebook==7.1.3 +nodeenv==1.9.0 +notebook==7.2.0 notebook_shim==0.2.4 +numba==0.59.1 numpy==1.26.4 opt-einsum==3.3.0 optree==0.11.0 -osqp==0.6.5 +osqp==0.6.7 overrides==7.7.0 packaging==24.0 pandocfilters==1.5.1 @@ -118,7 +125,7 @@ platformdirs==4.2.2 pluggy==1.5.0 pre-commit==3.7.1 prometheus_client==0.20.0 -prompt-toolkit==3.0.43 +prompt_toolkit==3.0.45 protobuf==4.25.3 psutil==5.9.8 ptyprocess==0.7.0 @@ -127,9 +134,10 @@ py==1.11.0 py-cpuinfo==9.0.0 pycparser==2.22 Pygments==2.18.0 +pylint==2.7.4 pyparsing==3.1.2 pyperclip==1.8.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -141,15 +149,16 @@ PyYAML==6.0.1 pyzmq==26.0.3 pyzx==0.8.0 qdldl==0.1.7.post2 +quimb==1.8.1 referencing==0.35.1 -requests==2.31.0 +requests==2.32.3 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.1 rpds-py==0.18.1 rustworkx==0.14.2 -scipy==1.11.4 -scs==3.2.4.post1 +scipy==1.12.0 +scs==3.2.4.post2 semantic-version==2.10.0 Send2Trash==1.8.3 six==1.16.0 @@ -168,11 +177,12 @@ tinycss2==1.3.0 toml==0.10.2 tomli==2.0.1 tomlkit==0.12.5 +toolz==0.12.1 tornado==6.4 tqdm==4.66.4 traitlets==5.14.3 types-python-dateutil==2.9.0.20240316 -typing_extensions==4.11.0 +typing_extensions==4.12.1 uri-template==1.3.0 urllib3==2.2.1 virtualenv==20.26.2 @@ -182,5 +192,5 @@ webencodings==0.5.1 websocket-client==1.8.0 Werkzeug==3.0.3 widgetsnbextension==3.6.6 -wrapt==1.16.0 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/jax.txt b/.github/stable/jax.txt index d3181d8b927..4978dada90c 100644 --- a/.github/stable/jax.txt +++ b/.github/stable/jax.txt @@ -1,15 +1,16 @@ appdirs==1.4.4 +astroid==2.6.6 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -19,7 +20,7 @@ exceptiongroup==1.2.1 execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 -fonttools==4.51.0 +fonttools==4.53.0 future==1.0.0 identify==2.5.36 idna==3.7 @@ -30,14 +31,16 @@ isort==5.13.2 jax==0.4.23 jaxlib==0.4.23 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 matplotlib==3.9.0 +mccabe==0.6.1 ml-dtypes==0.4.0 mypy-extensions==1.0.0 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 opt-einsum==3.3.0 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -47,8 +50,9 @@ pluggy==1.5.0 pre-commit==3.7.1 py==1.11.0 py-cpuinfo==9.0.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -58,15 +62,16 @@ pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rustworkx==0.14.2 -scipy==1.11.4 -scs==3.2.4.post1 +scipy==1.12.0 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/tf.txt b/.github/stable/tf.txt index 41cdf783405..1bf9daf1825 100644 --- a/.github/stable/tf.txt +++ b/.github/stable/tf.txt @@ -1,17 +1,18 @@ absl-py==2.1.0 appdirs==1.4.4 +astroid==2.6.6 astunparse==1.6.3 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -22,11 +23,11 @@ execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 flatbuffers==24.3.25 -fonttools==4.51.0 +fonttools==4.53.0 future==1.0.0 gast==0.5.4 google-pasta==0.2.0 -grpcio==1.63.0 +grpcio==1.64.0 h5py==3.11.0 identify==2.5.36 idna==3.7 @@ -36,21 +37,23 @@ iniconfig==2.0.0 isort==5.13.2 keras==3.3.3 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 libclang==18.1.1 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.0 +mccabe==0.6.1 mdurl==0.1.2 ml-dtypes==0.3.2 mypy-extensions==1.0.0 namex==0.0.8 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 opt-einsum==3.3.0 optree==0.11.0 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -62,8 +65,9 @@ protobuf==4.25.3 py==1.11.0 py-cpuinfo==9.0.0 Pygments==2.18.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -73,11 +77,11 @@ pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rich==13.7.1 rustworkx==0.14.2 scipy==1.11.4 -scs==3.2.4.post1 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 tensorboard==2.16.2 @@ -88,9 +92,9 @@ termcolor==2.4.0 tf_keras==2.16.0 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 Werkzeug==3.0.3 -wrapt==1.16.0 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/.github/stable/torch.txt b/.github/stable/torch.txt index 1acf7336c96..18a7db0f881 100644 --- a/.github/stable/torch.txt +++ b/.github/stable/torch.txt @@ -1,15 +1,16 @@ appdirs==1.4.4 +astroid==2.6.6 autograd==1.6.2 -autoray==0.6.9 +autoray==0.6.12 black==24.4.2 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.6.2 cfgv==3.4.0 charset-normalizer==3.3.2 -clarabel==0.7.1 +clarabel==0.9.0 click==8.1.7 contourpy==1.2.1 -coverage==7.5.1 +coverage==7.5.3 cvxopt==1.3.2 cvxpy==1.5.1 cycler==0.12.1 @@ -19,7 +20,7 @@ exceptiongroup==1.2.1 execnet==2.1.1 filelock==3.14.0 flaky==3.8.1 -fonttools==4.51.0 +fonttools==4.53.0 fsspec==2024.5.0 future==1.0.0 identify==2.5.36 @@ -29,14 +30,16 @@ iniconfig==2.0.0 isort==5.13.2 Jinja2==3.1.4 kiwisolver==1.4.5 +lazy-object-proxy==1.10.0 MarkupSafe==2.1.5 matplotlib==3.9.0 +mccabe==0.6.1 mpmath==1.3.0 mypy-extensions==1.0.0 networkx==3.2.1 -nodeenv==1.8.0 +nodeenv==1.9.0 numpy==1.26.4 -osqp==0.6.5 +osqp==0.6.7 packaging==24.0 pathspec==0.12.1 PennyLane_Lightning==0.37.0 @@ -46,8 +49,9 @@ pluggy==1.5.0 pre-commit==3.7.1 py==1.11.0 py-cpuinfo==9.0.0 +pylint==2.7.4 pyparsing==3.1.2 -pytest==8.2.0 +pytest==8.2.1 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-forked==1.6.0 @@ -57,17 +61,18 @@ python-dateutil==2.9.0.post0 pytorch-triton-rocm==2.3.0 PyYAML==6.0.1 qdldl==0.1.7.post2 -requests==2.31.0 +requests==2.32.3 rustworkx==0.14.2 scipy==1.11.4 -scs==3.2.4.post1 +scs==3.2.4.post2 semantic-version==2.10.0 six==1.16.0 -sympy==1.12 +sympy==1.12.1 toml==0.10.2 tomli==2.0.1 torch==2.3.0+rocm6.0 -typing_extensions==4.11.0 +typing_extensions==4.12.1 urllib3==2.2.1 virtualenv==20.26.2 -zipp==3.18.2 +wrapt==1.12.1 +zipp==3.19.1 diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3d90194d325..6d198b3b145 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -108,9 +108,10 @@ `m = measure(0); qml.sample(m)`. [(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673) -* PennyLane operators and measurements can now automatically be captured as instructions in JAXPR. +* PennyLane operators, measurements, and QNodes can now automatically be captured as instructions in JAXPR. [(#5564)](https://github.com/PennyLaneAI/pennylane/pull/5564) [(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511) + [(#5708)](https://github.com/PennyLaneAI/pennylane/pull/5708) [(#5523)](https://github.com/PennyLaneAI/pennylane/pull/5523) * The `decompose` transform has an `error` kwarg to specify the type of error that should be raised, @@ -151,6 +152,10 @@ * ``qml.QutritDepolarizingChannel`` has been added, allowing for depolarizing noise to be simulated on the `default.qutrit.mixed` device. [(#5502)](https://github.com/PennyLaneAI/pennylane/pull/5502) +* `qml.QutritAmplitudeDamping` channel has been added, allowing for noise processes modelled by amplitude damping to be simulated on the `default.qutrit.mixed` device. + [(#5503)](https://github.com/PennyLaneAI/pennylane/pull/5503) + [(#5757)](https://github.com/PennyLaneAI/pennylane/pull/5757) +

Breaking changes 💔

* A custom decomposition can no longer be provided to `QDrift`. Instead, apply the operations in your custom @@ -176,9 +181,6 @@ * `Controlled.wires` does not include `self.work_wires` anymore. That can be accessed separately through `Controlled.work_wires`. Consequently, `Controlled.active_wires` has been removed in favour of the more common `Controlled.wires`. [(#5728)](https://github.com/PennyLaneAI/pennylane/pull/5728) - -* `qml.QutritAmplitudeDamping` channel has been added, allowing for noise processes modelled by amplitude damping to be simulated on the `default.qutrit.mixed` device. - [(#5503)](https://github.com/PennyLaneAI/pennylane/pull/5503)

Deprecations 👋

diff --git a/pennylane/__init__.py b/pennylane/__init__.py index 80419832ae0..139ebc29222 100644 --- a/pennylane/__init__.py +++ b/pennylane/__init__.py @@ -218,6 +218,9 @@ def device(name, *args, **kwargs): * :mod:`'default.qutrit' `: a simple state simulator of qutrit-based quantum circuit architectures. + * :mod:`'default.qutrit.mixed' `: a + mixed-state simulator of qutrit-based quantum circuit architectures. + * :mod:`'default.gaussian' `: a simple simulator of Gaussian states and operations on continuous-variable circuit architectures. diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 33edf07dbe6..3cbb39a3dfa 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -33,6 +33,7 @@ ~create_measurement_obs_primitive ~create_measurement_wires_primitive ~create_measurement_mcm_primitive + ~qnode_call To activate and deactivate the new PennyLane program capturing mechanism, use the switches ``qml.capture.enable`` and ``qml.capture.disable``. @@ -132,6 +133,7 @@ def _(*args, **kwargs): create_measurement_wires_primitive, create_measurement_mcm_primitive, ) +from .capture_qnode import qnode_call def __getattr__(key): diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py new file mode 100644 index 00000000000..bbcd731e934 --- /dev/null +++ b/pennylane/capture/capture_qnode.py @@ -0,0 +1,167 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule defines a capture compatible call to QNodes. +""" + +from functools import lru_cache, partial + +import pennylane as qml + +has_jax = True +try: + import jax +except ImportError: + has_jax = False + + +def _get_shapes_for(*measurements, shots=None, num_device_wires=0): + if jax.config.jax_enable_x64: + dtype_map = { + float: jax.numpy.float64, + int: jax.numpy.int64, + complex: jax.numpy.complex128, + } + else: + dtype_map = { + float: jax.numpy.float32, + int: jax.numpy.int32, + complex: jax.numpy.complex64, + } + + shapes = [] + if not shots: + shots = [None] + + for s in shots: + for m in measurements: + shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires) + shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype))) + return shapes + + +@lru_cache() +def _get_qnode_prim(): + if not has_jax: + return None + qnode_prim = jax.core.Primitive("qnode") + qnode_prim.multiple_results = True + + @qnode_prim.def_impl + def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + def qfunc(*inner_args): + return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args) + + qnode = qml.QNode(qfunc, device, **qnode_kwargs) + return qnode._impl_call(*args, shots=shots) # pylint: disable=protected-access + + # pylint: disable=unused-argument + @qnode_prim.def_abstract_eval + def _(*args, shots, device, qnode_kwargs, qfunc_jaxpr): + mps = qfunc_jaxpr.out_avals + return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) + + return qnode_prim + + +# pylint: disable=protected-access +def _get_device_shots(device) -> "qml.measurements.Shots": + if isinstance(device, qml.devices.LegacyDevice): + if device._shot_vector: + return qml.measurements.Shots(device._raw_shot_sequence) + return qml.measurements.Shots(device.shots) + return device.shots + + +def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result": + """A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``. + + Args: + qnode (QNode): a QNode + args: the arguments the QNode is called with + + Keyword Args: + Any keyword arguments accepted by the quantum function + + Returns: + qml.typing.Result: the result of a qnode execution + + **Example:** + + .. code-block:: python + + @qml.qnode(qml.device('lightning.qubit', wires=1)) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.Z(0)), qml.probs() + + def f(x): + expval_z, probs = circuit(np.pi * x, shots=50) + return 2*expval_z + probs + + jaxpr = jax.make_jaxpr(f)(0.1) + print("jaxpr:") + print(jaxpr) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.7) + print() + print("result:") + print(res) + + + .. code-block:: none + + jaxpr: + { lambda ; a:f32[]. let + b:f32[] = mul 3.141592653589793 a + c:f32[] d:f32[2] = qnode[ + device= + qfunc_jaxpr={ lambda ; e:f32[]. let + _:AbstractOperator() = RX[n_wires=1] e 0 + f:AbstractOperator() = PauliZ[n_wires=1] 0 + g:AbstractMeasurement(n_wires=None) = expval_obs f + h:AbstractMeasurement(n_wires=0) = probs_wires + in (g, h) } + qnode_kwargs={'diff_method': 'best', 'grad_on_execution': 'best', 'cache': False, 'cachesize': 10000, 'max_diff': 1, 'max_expansion': 10, 'device_vjp': False} + shots=Shots(total=50) + ] b + i:f32[] = mul 2.0 c + j:f32[2] = add i d + in (j,) } + + result: + [Array([-1.3 , -0.74], dtype=float32)] + + + """ + if "shots" in kwargs: + shots = qml.measurements.Shots(kwargs.pop("shots")) + else: + shots = _get_device_shots(qnode.device) + if shots.has_partitioned_shots: + # Questions over the pytrees and the nested result object shape + raise NotImplementedError("shot vectors are not yet supported with plxpr capture.") + + if not qnode.device.wires: + raise NotImplementedError("devices must specify wires for integration with plxpr capture.") + + qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func + + qfunc_jaxpr = jax.make_jaxpr(qfunc)(*args) + qnode_kwargs = {"diff_method": qnode.diff_method, **qnode.execute_kwargs} + qnode_prim = _get_qnode_prim() + + return qnode_prim.bind( + *args, shots=shots, device=qnode.device, qnode_kwargs=qnode_kwargs, qfunc_jaxpr=qfunc_jaxpr + ) diff --git a/pennylane/ops/qutrit/channel.py b/pennylane/ops/qutrit/channel.py index c50572cf206..82f5d9c4407 100644 --- a/pennylane/ops/qutrit/channel.py +++ b/pennylane/ops/qutrit/channel.py @@ -235,8 +235,10 @@ class QutritAmplitudeDamping(Channel): K_0 = \begin{bmatrix} 1 & 0 & 0\\ 0 & \sqrt{1-\gamma_1} & 0 \\ - 0 & 0 & \sqrt{1-\gamma_2} - \end{bmatrix}, \quad + 0 & 0 & \sqrt{1-(\gamma_2+\gamma_3)} + \end{bmatrix} + + .. math:: K_1 = \begin{bmatrix} 0 & \sqrt{\gamma_1} & 0 \\ 0 & 0 & 0 \\ @@ -246,70 +248,95 @@ class QutritAmplitudeDamping(Channel): 0 & 0 & \sqrt{\gamma_2} \\ 0 & 0 & 0 \\ 0 & 0 & 0 + \end{bmatrix}, \quad + K_3 = \begin{bmatrix} + 0 & 0 & 0 \\ + 0 & 0 & \sqrt{\gamma_3} \\ + 0 & 0 & 0 \end{bmatrix} - where :math:`\gamma_1 \in [0, 1]` and :math:`\gamma_2 \in [0, 1]` are the amplitude damping - probabilities for subspaces (0,1) and (0,2) respectively. + where :math:`\gamma_1, \gamma_2, \gamma_3 \in [0, 1]` are the amplitude damping + probabilities for subspaces (0,1), (0,2), and (1,2) respectively. .. note:: - The Kraus operators :math:`\{K_0, K_1, K_2\}` are adapted from [`1 `_] (Eq. 8). + When :math:`\gamma_3=0` then Kraus operators :math:`\{K_0, K_1, K_2\}` are adapted from + [`1 `_] (Eq. 8). + + The Kraus operator :math:`K_3` represents the :math:`|2 \rangle \rightarrow |1 \rangle` transition which is more + likely on some devices [`2 `_] (Sec II.A). + + To maintain normalization :math:`\gamma_2 + \gamma_3 \leq 1`. + **Details:** * Number of wires: 1 - * Number of parameters: 2 + * Number of parameters: 3 Args: gamma_1 (float): :math:`|1 \rangle \rightarrow |0 \rangle` amplitude damping probability. gamma_2 (float): :math:`|2 \rangle \rightarrow |0 \rangle` amplitude damping probability. + gamma_3 (float): :math:`|2 \rangle \rightarrow |1 \rangle` amplitude damping probability. wires (Sequence[int] or int): the wire the channel acts on id (str or None): String representing the operation (optional) """ - num_params = 2 + num_params = 3 num_wires = 1 grad_method = "F" - def __init__(self, gamma_1, gamma_2, wires, id=None): - # Verify gamma_1 and gamma_2 - for gamma in (gamma_1, gamma_2): - if not (math.is_abstract(gamma_1) or math.is_abstract(gamma_2)): + def __init__(self, gamma_1, gamma_2, gamma_3, wires, id=None): + # Verify input + for gamma in (gamma_1, gamma_2, gamma_3): + if not math.is_abstract(gamma): if not 0.0 <= gamma <= 1.0: raise ValueError("Each probability must be in the interval [0,1]") - super().__init__(gamma_1, gamma_2, wires=wires, id=id) + if not (math.is_abstract(gamma_2) or math.is_abstract(gamma_3)): + if not 0.0 <= gamma_2 + gamma_3 <= 1.0: + raise ValueError(r"\gamma_2+\gamma_3 must be in the interval [0,1]") + super().__init__(gamma_1, gamma_2, gamma_3, wires=wires, id=id) @staticmethod - def compute_kraus_matrices(gamma_1, gamma_2): # pylint:disable=arguments-differ + def compute_kraus_matrices(gamma_1, gamma_2, gamma_3): # pylint:disable=arguments-differ r"""Kraus matrices representing the ``QutritAmplitudeDamping`` channel. Args: gamma_1 (float): :math:`|1\rangle \rightarrow |0\rangle` amplitude damping probability. gamma_2 (float): :math:`|2\rangle \rightarrow |0\rangle` amplitude damping probability. + gamma_3 (float): :math:`|2\rangle \rightarrow |1\rangle` amplitude damping probability. Returns: list(array): list of Kraus matrices **Example** - >>> qml.QutritAmplitudeDamping.compute_kraus_matrices(0.5, 0.25) + >>> qml.QutritAmplitudeDamping.compute_kraus_matrices(0.5, 0.25, 0.36) [ array([ [1. , 0. , 0. ], [0. , 0.70710678, 0. ], - [0. , 0. , 0.8660254 ]]), + [0. , 0. , 0.6244998 ]]), array([ [0. , 0.70710678, 0. ], [0. , 0. , 0. ], [0. , 0. , 0. ]]), array([ [0. , 0. , 0.5 ], [0. , 0. , 0. ], [0. , 0. , 0. ]]) + array([ [0. , 0. , 0. ], + [0. , 0. , 0.6 ], + [0. , 0. , 0. ]]) ] """ - K0 = math.diag([1, math.sqrt(1 - gamma_1 + math.eps), math.sqrt(1 - gamma_2 + math.eps)]) + K0 = math.diag( + [1, math.sqrt(1 - gamma_1 + math.eps), math.sqrt(1 - gamma_2 - gamma_3 + math.eps)] + ) K1 = math.sqrt(gamma_1 + math.eps) * math.convert_like( math.cast_like(math.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]), gamma_1), gamma_1 ) K2 = math.sqrt(gamma_2 + math.eps) * math.convert_like( math.cast_like(math.array([[0, 0, 1], [0, 0, 0], [0, 0, 0]]), gamma_2), gamma_2 ) - return [K0, K1, K2] + K3 = math.sqrt(gamma_3 + math.eps) * math.convert_like( + math.cast_like(math.array([[0, 0, 0], [0, 0, 1], [0, 0, 0]]), gamma_3), gamma_3 + ) + return [K0, K1, K2, K3] diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 57d6e89c28b..947a636b08b 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 Xanadu Quantum Technologies Inc. +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1081,7 +1081,7 @@ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml res, self._qfunc_output, self._tape.shots.has_partitioned_shots ) - def __call__(self, *args, **kwargs) -> qml.typing.Result: + def _impl_call(self, *args, **kwargs) -> qml.typing.Result: old_interface = self.interface if old_interface == "auto": @@ -1113,6 +1113,11 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: return res + def __call__(self, *args, **kwargs) -> qml.typing.Result: + if qml.capture.enabled(): + return qml.capture.qnode_call(self, *args, **kwargs) + return self._impl_call(*args, **kwargs) + qnode = lambda device, **kwargs: functools.partial(QNode, device=device, **kwargs) qnode.__doc__ = QNode.__doc__ diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py new file mode 100644 index 00000000000..1fe57c65e23 --- /dev/null +++ b/tests/capture/test_capture_qnode.py @@ -0,0 +1,298 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for capturing a qnode into jaxpr. +""" +from functools import partial + +# pylint: disable=protected-access +import pytest + +import pennylane as qml +from pennylane.capture.capture_qnode import _get_qnode_prim + +qnode_prim = _get_qnode_prim() + +pytestmark = pytest.mark.jax + +jax = pytest.importorskip("jax") + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + qml.capture.enable() + yield + qml.capture.disable() + + +@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy")) +def test_error_if_shot_vector(dev_name): + """Test that a NotImplementedError is raised if a shot vector is provided.""" + + dev = qml.device(dev_name, wires=1, shots=(50, 50)) + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"): + jax.make_jaxpr(circuit)() + + with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"): + circuit() + + jax.make_jaxpr(partial(circuit, shots=50))() # should run fine + res = circuit(shots=50) + assert qml.math.allclose(res, jax.numpy.zeros((50,))) + + +@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy")) +def test_error_if_overridden_shot_vector(dev_name): + """Test that a NotImplementedError is raised if a shot vector is provided on call.""" + + dev = qml.device(dev_name, wires=1) + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + with pytest.raises(NotImplementedError, match="shot vectors are not yet supported"): + jax.make_jaxpr(partial(circuit, shots=(1, 1, 1)))() + + +def test_error_if_no_device_wires(): + """Test that a NotImplementedError is raised if the device does not provide wires.""" + + dev = qml.device("default.qubit") + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + with pytest.raises(NotImplementedError, match="devices must specify wires"): + jax.make_jaxpr(circuit)() + + with pytest.raises(NotImplementedError, match="devices must specify wires"): + circuit() + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_simple_qnode(x64_mode): + """Test capturing a qnode for a simple use.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + dev = qml.device("default.qubit", wires=4) + + @qml.qnode(dev) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.Z(0)) + + res = circuit(0.5) + assert qml.math.allclose(res, jax.numpy.cos(0.5)) + + jaxpr = jax.make_jaxpr(circuit)(0.5) + + assert len(jaxpr.eqns) == 1 + eqn0 = jaxpr.eqns[0] + + fdtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + assert jaxpr.in_avals == [jax.core.ShapedArray((), fdtype, weak_type=True)] + + assert eqn0.primitive == qnode_prim + assert eqn0.invars[0].aval == jaxpr.in_avals[0] + assert jaxpr.out_avals[0] == jax.core.ShapedArray((), fdtype) + + assert eqn0.params["device"] == dev + assert eqn0.params["shots"] == qml.measurements.Shots(None) + expected_kwargs = {"diff_method": "best"} + expected_kwargs.update(circuit.execute_kwargs) + assert eqn0.params["qnode_kwargs"] == expected_kwargs + + qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] + assert len(qfunc_jaxpr.eqns) == 3 + assert qfunc_jaxpr.eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + + assert len(eqn0.outvars) == 1 + assert eqn0.outvars[0].aval == jax.core.ShapedArray((), fdtype) + + output = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5) + assert qml.math.allclose(output[0], jax.numpy.cos(0.5)) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("dev_name", ("default.qubit", "default.qubit.legacy")) +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_overriding_shots(dev_name, x64_mode): + """Test that the number of shots can be overridden on call.""" + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + dev = qml.device(dev_name, wires=1) + + @qml.qnode(dev) + def circuit(): + return qml.sample() + + jaxpr = jax.make_jaxpr(partial(circuit, shots=50))() + assert len(jaxpr.eqns) == 1 + eqn0 = jaxpr.eqns[0] + + assert eqn0.primitive == qnode_prim + assert eqn0.params["device"] == dev + assert eqn0.params["shots"] == qml.measurements.Shots(50) + assert ( + eqn0.params["qfunc_jaxpr"].eqns[0].primitive == qml.measurements.SampleMP._wires_primitive + ) + + assert eqn0.outvars[0].aval == jax.core.ShapedArray( + (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res, jax.numpy.zeros((50,))) + + jax.config.update("jax_enable_x64", initial_mode) + + +def test_providing_keyword_argument(): + """Test that keyword arguments can be provided to the qnode.""" + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(*, n_iterations=0): + for _ in range(n_iterations): + qml.X(0) + return qml.probs() + + jaxpr = jax.make_jaxpr(partial(circuit, n_iterations=3))() + + assert jaxpr.eqns[0].primitive == qnode_prim + + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + for i in range(3): + assert qfunc_jaxpr.eqns[i].primitive == qml.PauliX._primitive + assert len(qfunc_jaxpr.eqns) == 4 + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res, jax.numpy.array([0, 1])) + + res2 = circuit(n_iterations=4) + assert qml.math.allclose(res2, jax.numpy.array([1, 0])) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_multiple_measurements(x64_mode): + """Test that the qnode can return multiple measurements.""" + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + @qml.qnode(qml.device("default.qubit", wires=3, shots=50)) + def circuit(): + return qml.sample(), qml.probs(wires=(0, 1)), qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)() + + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + + assert qfunc_jaxpr.eqns[0].primitive == qml.measurements.SampleMP._wires_primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.measurements.ProbabilityMP._wires_primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + assert jaxpr.out_avals[0] == jax.core.ShapedArray( + (50, 3), jax.numpy.int64 if x64_mode else jax.numpy.int32 + ) + assert jaxpr.out_avals[1] == jax.core.ShapedArray( + (4,), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + assert jaxpr.out_avals[2] == jax.core.ShapedArray( + (), jax.numpy.float64 if x64_mode else jax.numpy.float32 + ) + + res1, res2, res3 = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res1, jax.numpy.zeros((50, 3))) + assert qml.math.allclose(res2, jax.numpy.array([1, 0, 0, 0])) + assert qml.math.allclose(res3, 1.0) + + res1, res2, res3 = circuit() + assert qml.math.allclose(res1, jax.numpy.zeros((50, 3))) + assert qml.math.allclose(res2, jax.numpy.array([1, 0, 0, 0])) + assert qml.math.allclose(res3, 1.0) + + jax.config.update("jax_enable_x64", initial_mode) + + +@pytest.mark.parametrize("x64_mode", (True, False)) +def test_complex_return_types(x64_mode): + """Test returning measurements with complex values.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + + @qml.qnode(qml.device("default.qubit", wires=3)) + def circuit(): + return qml.state(), qml.density_matrix(wires=(0, 1)) + + jaxpr = jax.make_jaxpr(circuit)() + + qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + + assert qfunc_jaxpr.eqns[0].primitive == qml.measurements.StateMP._wires_primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.measurements.DensityMatrixMP._wires_primitive + + assert jaxpr.out_avals[0] == jax.core.ShapedArray( + (8,), jax.numpy.complex128 if x64_mode else jax.numpy.complex64 + ) + assert jaxpr.out_avals[1] == jax.core.ShapedArray( + (4, 4), jax.numpy.complex128 if x64_mode else jax.numpy.complex64 + ) + + jax.config.update("jax_enable_x64", initial_mode) + + +def test_capture_qnode_kwargs(): + """Test that qnode kwargs are captured as parameters.""" + + dev = qml.device("default.qubit", wires=3) + + @qml.qnode( + dev, + diff_method="parameter-shift", + grad_on_execution=False, + cache=True, + cachesize=10, + max_diff=2, + ) + def circuit(): + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit)() + + assert jaxpr.eqns[0].primitive == qnode_prim + expected = { + "diff_method": "parameter-shift", + "grad_on_execution": False, + "cache": True, + "cachesize": 10, + "max_diff": 2, + "max_expansion": 10, + "device_vjp": False, + } + assert jaxpr.eqns[0].params["qnode_kwargs"] == expected diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py index 9574fdbec25..0d80adf2943 100644 --- a/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py +++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_preprocessing.py @@ -126,6 +126,7 @@ def test_measurement_is_swapped_out(self, mp_fn, mp_cls, shots): (qml.Snapshot(), True), (qml.TRX(1.1, 0), True), (qml.QutritDepolarizingChannel(0.4, 0), True), + (qml.QutritAmplitudeDamping(0.1, 0.2, 0.12, 0), True), ], ) def test_accepted_operator(self, op, expected): diff --git a/tests/ops/qutrit/test_qutrit_channel_ops.py b/tests/ops/qutrit/test_qutrit_channel_ops.py index 585f623f499..4b2612815ce 100644 --- a/tests/ops/qutrit/test_qutrit_channel_ops.py +++ b/tests/ops/qutrit/test_qutrit_channel_ops.py @@ -176,15 +176,15 @@ class TestQutritAmplitudeDamping: def test_gamma_zero(self, tol): """Test gamma_1=gamma_2=0 gives correct Kraus matrices""" - kraus_mats = qml.QutritAmplitudeDamping(0, 0, wires=0).kraus_matrices() + kraus_mats = qml.QutritAmplitudeDamping(0, 0, 0, wires=0).kraus_matrices() assert np.allclose(kraus_mats[0], np.eye(3), atol=tol, rtol=0) - assert np.allclose(kraus_mats[1], np.zeros((3, 3)), atol=tol, rtol=0) - assert np.allclose(kraus_mats[2], np.zeros((3, 3)), atol=tol, rtol=0) + for kraus_mat in kraus_mats[1:]: + assert np.allclose(kraus_mat, np.zeros((3, 3)), atol=tol, rtol=0) - @pytest.mark.parametrize("gamma1,gamma2", ((0.1, 0.2), (0.75, 0.75))) - def test_gamma_arbitrary(self, gamma1, gamma2, tol): - """Test the correct Kraus matrices are returned, also ensures that the sum of gammas can be over 1.""" - K_0 = np.diag((1, np.sqrt(1 - gamma1), np.sqrt(1 - gamma2))) + @pytest.mark.parametrize("gamma1,gamma2,gamma3", ((0.1, 0.2, 0.3), (0.75, 0.75, 0.25))) + def test_gamma_arbitrary(self, gamma1, gamma2, gamma3, tol): + """Test the correct Kraus matrices are returned.""" + K_0 = np.diag((1, np.sqrt(1 - gamma1), np.sqrt(1 - gamma2 - gamma3))) K_1 = np.zeros((3, 3)) K_1[0, 1] = np.sqrt(gamma1) @@ -192,33 +192,48 @@ def test_gamma_arbitrary(self, gamma1, gamma2, tol): K_2 = np.zeros((3, 3)) K_2[0, 2] = np.sqrt(gamma2) - expected = [K_0, K_1, K_2] - damping_channel = qml.QutritAmplitudeDamping(gamma1, gamma2, wires=0) + K_3 = np.zeros((3, 3)) + K_3[1, 2] = np.sqrt(gamma3) + + expected = [K_0, K_1, K_2, K_3] + damping_channel = qml.QutritAmplitudeDamping(gamma1, gamma2, gamma3, wires=0) assert np.allclose(damping_channel.kraus_matrices(), expected, atol=tol, rtol=0) - @pytest.mark.parametrize("gamma1,gamma2", ((1.5, 0.0), (0.0, 1.0 + math.eps))) - def test_gamma_invalid_parameter(self, gamma1, gamma2): - """Ensures that error is thrown when gamma_1 or gamma_2 are outside [0,1]""" - with pytest.raises(ValueError, match="Each probability must be in the interval"): - channel.QutritAmplitudeDamping(gamma1, gamma2, wires=0).kraus_matrices() + @pytest.mark.parametrize( + "gamma1,gamma2,gamma3", + ( + (1.5, 0.0, 0.0), + (0.0, 1.0 + math.eps, 0.0), + (0.0, 0.0, 1.1), + (0.0, 0.33, 0.67 + math.eps), + ), + ) + def test_gamma_invalid_parameter(self, gamma1, gamma2, gamma3): + """Ensures that error is thrown when gamma_1, gamma_2, gamma_3, or (gamma_2 + gamma_3) are outside [0,1]""" + with pytest.raises(ValueError, match="must be in the interval"): + channel.QutritAmplitudeDamping(gamma1, gamma2, gamma3, wires=0).kraus_matrices() @staticmethod - def expected_jac_fn(gamma_1, gamma_2): + def expected_jac_fn(gamma_1, gamma_2, gamma_3): """Gets the expected Jacobian of Kraus matrices""" - partial_1 = [math.zeros((3, 3)) for _ in range(3)] + partial_1 = [math.zeros((3, 3)) for _ in range(4)] partial_1[0][1, 1] = -1 / (2 * math.sqrt(1 - gamma_1)) partial_1[1][0, 1] = 1 / (2 * math.sqrt(gamma_1)) - partial_2 = [math.zeros((3, 3)) for _ in range(3)] - partial_2[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2)) + partial_2 = [math.zeros((3, 3)) for _ in range(4)] + partial_2[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2 - gamma_3)) partial_2[2][0, 2] = 1 / (2 * math.sqrt(gamma_2)) - return [partial_1, partial_2] + partial_3 = [math.zeros((3, 3)) for _ in range(4)] + partial_3[0][2, 2] = -1 / (2 * math.sqrt(1 - gamma_2 - gamma_3)) + partial_3[3][1, 2] = 1 / (2 * math.sqrt(gamma_3)) + + return [partial_1, partial_2, partial_3] @staticmethod - def kraus_fn(gamma_1, gamma_2): + def kraus_fn(gamma_1, gamma_2, gamma_3): """Gets the Kraus matrices of QutritAmplitudeDamping channel, used for differentiation.""" - damping_channel = qml.QutritAmplitudeDamping(gamma_1, gamma_2, wires=0) + damping_channel = qml.QutritAmplitudeDamping(gamma_1, gamma_2, gamma_3, wires=0) return math.stack(damping_channel.kraus_matrices()) @pytest.mark.autograd @@ -226,8 +241,10 @@ def test_kraus_jac_autograd(self): """Tests Jacobian of Kraus matrices using autograd.""" gamma_1 = pnp.array(0.43, requires_grad=True) gamma_2 = pnp.array(0.12, requires_grad=True) - jac = qml.jacobian(self.kraus_fn)(gamma_1, gamma_2) - assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2)) + gamma_3 = pnp.array(0.35, requires_grad=True) + + jac = qml.jacobian(self.kraus_fn)(gamma_1, gamma_2, gamma_3) + assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3)) @pytest.mark.torch def test_kraus_jac_torch(self): @@ -236,11 +253,15 @@ def test_kraus_jac_torch(self): gamma_1 = torch.tensor(0.43, requires_grad=True) gamma_2 = torch.tensor(0.12, requires_grad=True) + gamma_3 = torch.tensor(0.35, requires_grad=True) - jac = torch.autograd.functional.jacobian(self.kraus_fn, (gamma_1, gamma_2)) - expected = self.expected_jac_fn(gamma_1.detach().numpy(), gamma_2.detach().numpy()) - assert math.allclose(jac[0].detach().numpy(), expected[0]) - assert math.allclose(jac[1].detach().numpy(), expected[1]) + jac = torch.autograd.functional.jacobian(self.kraus_fn, (gamma_1, gamma_2, gamma_3)) + expected = self.expected_jac_fn( + gamma_1.detach().numpy(), gamma_2.detach().numpy(), gamma_3.detach().numpy() + ) + + for res_partial, exp_partial in zip(jac, expected): + assert math.allclose(res_partial.detach().numpy(), exp_partial) @pytest.mark.tf def test_kraus_jac_tf(self): @@ -249,10 +270,12 @@ def test_kraus_jac_tf(self): gamma_1 = tf.Variable(0.43) gamma_2 = tf.Variable(0.12) + gamma_3 = tf.Variable(0.35) + with tf.GradientTape() as tape: - out = self.kraus_fn(gamma_1, gamma_2) - jac = tape.jacobian(out, (gamma_1, gamma_2)) - assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2)) + out = self.kraus_fn(gamma_1, gamma_2, gamma_3) + jac = tape.jacobian(out, (gamma_1, gamma_2, gamma_3)) + assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3)) @pytest.mark.jax def test_kraus_jac_jax(self): @@ -261,5 +284,7 @@ def test_kraus_jac_jax(self): gamma_1 = jax.numpy.array(0.43) gamma_2 = jax.numpy.array(0.12) - jac = jax.jacobian(self.kraus_fn, argnums=[0, 1])(gamma_1, gamma_2) - assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2)) + gamma_3 = jax.numpy.array(0.35) + + jac = jax.jacobian(self.kraus_fn, argnums=[0, 1, 2])(gamma_1, gamma_2, gamma_3) + assert math.allclose(jac, self.expected_jac_fn(gamma_1, gamma_2, gamma_3))