qujax is a JAX-based Python library for the classical simulation of quantum circuits. It is designed to be simple, fast and flexible.
It follows a functional programming design by translating circuits into pure functions. This allows qujax to seamlessly interface with JAX, enabling direct access to its powerful automatic differentiation tools, just-in-time compiler, vectorization capabilities, GPU/TPU integration and growing ecosystem of packages.
qujax can be used for both pure and mixed quantum state simulation. It supports not only the standard gate set but also user-defined custom operations, including general quantum channels, so that users can, for example, model device noise and errors.
A summary of the core functionalities of qujax can be found in the Quick start section. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the Examples section of the documentation.
qujax is hosted on PyPI and can be installed via the pip package manager:

```
pip install qujax
```
Important note: qujax circuit parameters are expressed in units of π (e.g. in the range [0, 2) as opposed to [0, 2π)).
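Because parameters are in units of π, a parameter value p corresponds to a rotation angle of πp radians. The minimal NumPy sketch below makes this concrete, assuming the standard definition Ry(θ) = exp(−iθY/2); `ry_matrix` is a hypothetical helper for illustration, not part of the qujax API.

```python
import numpy as np


def ry_matrix(param: float) -> np.ndarray:
    """Hypothetical helper: Ry gate for a parameter given in units of pi."""
    theta = np.pi * param  # convert the unit-of-pi parameter to radians
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])


ket0 = np.array([1.0, 0.0])
print(ry_matrix(1.0) @ ket0)  # parameter 1.0 = rotation by pi: |0> -> |1> (up to rounding)
print(ry_matrix(2.0))         # parameter 2.0 = full 2*pi rotation: -I (up to rounding)
```

Under this convention the parameter value 0.1 used in the quick start is a rotation of 0.1π radians (18°).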
Start by defining the quantum gates making up the circuit, the qubits that they act on, and the indices of the parameters for each gate.
A list of all gates can be found here (custom operations can be included by passing an array or function instead of a string).
```python
from jax import numpy as jnp
import qujax

# List of quantum gates
circuit_gates = ['H', 'Ry', 'CZ']
# Indices of qubits the gates will be applied to
circuit_qubit_inds = [[0], [0], [0, 1]]
# Indices of parameters each parameterised gate will use
circuit_params_inds = [[], [0], []]

qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds)
# q0: -----H-----Ry[0]-----◯---
#                          |
# q1: ---------------------CZ--
```
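Custom operations take the place of a gate-name string in `circuit_gates`. Below is a minimal sketch of a fixed custom gate, here the square root of X written as a NumPy array, together with a unitarity check that any custom gate should pass. The names `sqrt_x` and `is_unitary` are hypothetical, and how qujax converts or validates the array internally is not shown. A custom gate supplied as a function should be written with `jax.numpy` rather than NumPy so that it stays traceable under `jit` and `grad`.

```python
import numpy as np

# Hypothetical custom gate supplied as a fixed array: the square root of X
sqrt_x = 0.5 * np.array([[1 + 1j, 1 - 1j],
                         [1 - 1j, 1 + 1j]])


def is_unitary(u: np.ndarray, atol: float = 1e-8) -> bool:
    """Return True if U @ U^dagger equals the identity (up to atol)."""
    return np.allclose(u @ u.conj().T, np.eye(u.shape[0]), atol=atol)


print(is_unitary(sqrt_x))                              # True
print(np.allclose(sqrt_x @ sqrt_x, [[0, 1], [1, 0]]))  # True: squaring gives X

# Assumed usage: the array replaces a gate name in the circuit lists, e.g.
# circuit_gates = ['H', sqrt_x, 'CZ'] with circuit_params_inds = [[], [], []]
```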
Translate the circuit to a pure function `param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input.
```python
param_to_st = qujax.get_params_to_statetensor_func(
    circuit_gates, circuit_qubit_inds, circuit_params_inds
)

param_to_st(jnp.array([0.1]))
# Array([[0.58778524+0.j, 0.        +0.j],
#        [0.80901706+0.j, 0.        +0.j]], dtype=complex64)
```
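To see what `param_to_st` returns, the same circuit can be simulated by hand. This NumPy sketch is an independent cross-check, not qujax's implementation; the helpers `ry`, `apply_1q` and `apply_cz` are hypothetical. The statetensor has one axis of size 2 per qubit, so `st[i, j]` is the amplitude of the basis state with qubit 0 in |i⟩ and qubit 1 in |j⟩, which is consistent with the output above.

```python
import numpy as np

H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)


def ry(param):
    """Ry gate with the parameter in units of pi."""
    theta = np.pi * param
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])


def apply_1q(gate, st, qubit):
    """Apply a single-qubit gate to axis `qubit` of a statetensor."""
    st = np.tensordot(gate, st, axes=([1], [qubit]))  # gate output becomes axis 0
    return np.moveaxis(st, 0, qubit)                  # move it back to position `qubit`


def apply_cz(st, q0, q1):
    """CZ flips the sign of every amplitude where both qubits are 1."""
    st = st.copy()
    idx = [slice(None)] * st.ndim
    idx[q0], idx[q1] = 1, 1
    st[tuple(idx)] *= -1
    return st


st = np.zeros((2, 2), dtype=complex)
st[0, 0] = 1.0  # start in |00>
st = apply_1q(H, st, 0)
st = apply_1q(ry(0.1), st, 0)
st = apply_cz(st, 0, 1)
print(np.round(st, 6))  # approx 0.587785 at st[0, 0] and 0.809017 at st[1, 0]
```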
The optional initial state can be passed to `param_to_st` using the `statetensor_in` argument. When it is not provided, the initial state defaults to |0…0⟩, i.e. all qubits in |0⟩.
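Initial states are statetensors with one axis of size 2 per qubit. A minimal NumPy sketch of building computational basis states, using a hypothetical helper `basis_statetensor`:

```python
import numpy as np


def basis_statetensor(bits):
    """Hypothetical helper: statetensor of the computational basis state |bits>."""
    st = np.zeros((2,) * len(bits), dtype=complex)
    st[tuple(bits)] = 1.0  # a single amplitude of 1 at the given bit pattern
    return st


default_in = basis_statetensor([0, 0])  # |00>, the default initial state
custom_in = basis_statetensor([1, 0])   # |10>: qubit 0 starts in |1>
print(default_in.shape)                 # (2, 2)
```

JAX arrays are immutable, so with `jax.numpy` the equivalent construction is `jnp.zeros((2,) * n, dtype=jnp.complex64).at[(0,) * n].set(1.0)`. A state built this way would then be supplied as `param_to_st(jnp.array([0.1]), statetensor_in=custom_in)`.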
Map the state to an expectation value by defining an observable using lists of Pauli matrices, the qubits they act on, and the associated coefficients.
```python
st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.])
```
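The observable arguments are parallel lists: each entry is a Pauli string, the qubits it acts on, and its coefficient, so the call above describes 1.0 × Z on qubit 0. The minimal NumPy sketch below evaluates the quantity this represents, ⟨ψ| Σₖ cₖ Pₖ |ψ⟩, on the quick-start statetensor; `pauli_sum_expectation` is a hypothetical re-implementation for illustration, not qujax's code.

```python
import numpy as np

PAULIS = {
    'X': np.array([[0, 1], [1, 0]], dtype=complex),
    'Y': np.array([[0, -1j], [1j, 0]], dtype=complex),
    'Z': np.array([[1, 0], [0, -1]], dtype=complex),
}


def pauli_sum_expectation(st, pauli_strings, qubit_inds, coefficients):
    """Hypothetical sketch: <psi| sum_k c_k P_k |psi> for a statetensor `st`."""
    total = 0.0
    for paulis, qubits, coeff in zip(pauli_strings, qubit_inds, coefficients):
        op_st = st
        for p, q in zip(paulis, qubits):
            # Apply the single-qubit Pauli to axis q of the statetensor
            op_st = np.moveaxis(np.tensordot(PAULIS[p], op_st, axes=([1], [q])), 0, q)
        total += coeff * np.vdot(st, op_st).real  # vdot conjugates the first argument
    return total


# Amplitudes copied from the param_to_st output in the quick start
st = np.array([[0.58778524, 0.0], [0.80901706, 0.0]], dtype=complex)
print(pauli_sum_expectation(st, [['Z']], [[0]], [1.0]))  # approx -0.309
```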
Combining `param_to_st` and `st_to_expectation` gives a parameter-to-expectation function that can be automatically differentiated using JAX.
```python
from jax import value_and_grad


def param_to_expectation(param):
    return st_to_expectation(param_to_st(param))


expectation_and_grad = value_and_grad(param_to_expectation)
expectation_and_grad(jnp.array([0.1]))
# (Array(-0.3090171, dtype=float32),
#  Array([-2.987832], dtype=float32))
```
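These values can be checked by hand. Since qubit 1 stays in |0⟩, the CZ acts trivially, and after H and Ry(πp) on qubit 0 the expectation simplifies to ⟨Z₀⟩ = −sin(πp), so its derivative is −π cos(πp); the extra factor π comes from the units-of-π convention. The sketch below compares this closed form with a central finite difference. Finite differences serve only as a cross-check here; JAX itself computes the gradient by automatic differentiation.

```python
import math


def expectation(p):
    """<Z_0> after H then Ry(p) on qubit 0, with p in units of pi.

    CZ is omitted: with qubit 1 in |0> it acts as the identity here.
    """
    theta = math.pi * p
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    a0 = (c - s) / math.sqrt(2)  # amplitude of |0> on qubit 0
    a1 = (s + c) / math.sqrt(2)  # amplitude of |1> on qubit 0
    return a0 ** 2 - a1 ** 2     # simplifies to -sin(pi * p)


def central_difference(f, x, h=1e-5):
    """Second-order accurate numerical derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)


p = 0.1
print(expectation(p))                       # approx -0.309017
print(central_difference(expectation, p))   # approx -2.98783
print(-math.pi * math.cos(math.pi * p))     # closed form, approx -2.98783
```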
Mixed state simulations are analogous to the above, but with calls to `get_params_to_densitytensor_func` and `get_densitytensor_to_expectation_func` instead.
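Density tensors extend the same idea to mixed states. The minimal NumPy sketch below assumes a layout of 2n axes of size 2, with the first n axes indexing the ket and the last n the bra (an assumption, not taken from the qujax source). It builds ρ = |ψ⟩⟨ψ| from the quick-start statetensor, recovers the same ⟨Z₀⟩ as Tr(ρ Z₀), and then mixes in a global depolarising term as an illustrative noise model, which scales the expectation by (1 − p).

```python
import numpy as np

# Amplitudes copied from the param_to_st output in the quick start
st = np.array([[0.58778524, 0.0], [0.80901706, 0.0]], dtype=complex)
n_qubits = st.ndim

# Density matrix rho = |psi><psi| as a (2**n, 2**n) matrix ...
psi = st.ravel()
rho = np.outer(psi, psi.conj())
# ... and as a density tensor with 2 * n_qubits axes of size 2
# (assumed layout: first n axes index the ket, last n axes index the bra)
dt = rho.reshape((2,) * (2 * n_qubits))

# Z on qubit 0; after ravel(), qubit 0 is the most significant index
Z0 = np.kron(np.diag([1.0, -1.0]), np.eye(2))
print(np.trace(rho @ Z0).real)  # same value as the pure-state expectation

# Illustrative noise: a global depolarising mixture with probability p
p = 0.2
rho_noisy = (1 - p) * rho + p * np.eye(2**n_qubits) / 2**n_qubits
print(np.trace(rho_noisy @ Z0).real)  # scaled by (1 - p)
```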
A more in-depth version of the above can be found in the Getting started section of the documentation, and more advanced use-cases, including the training of parameterised quantum circuits, in the Examples section.
A pytket circuit can be converted directly using the `tk_to_qujax` and `tk_to_qujax_symbolic` functions in the pytket-qujax extension. See `pytket-qujax_heisenberg_vqe.ipynb` for an example.
You can open a bug report or a feature request by creating a new issue on GitHub.
Pull requests are welcome! To open a new one, please go through the following steps:
- First fork the repo and create your branch from `develop`.
- Commit your code and tests.
- Update the documentation, if required.
- Check the code lints (run `black . --check` and `pylint */`).
- Issue a pull request into the `develop` branch.
New commits on `develop` will be merged into `main` in the next release.
If you have used qujax in your code or research, we kindly ask that you cite it. You can use the following BibTeX entry for this:
```bibtex
@article{qujax2023,
  author = {Duffield, Samuel and Matos, Gabriel and Johannsen, Melf},
  doi = {10.21105/joss.05504},
  journal = {Journal of Open Source Software},
  month = sep,
  number = {89},
  pages = {5504},
  title = {{qujax: Simulating quantum circuits with JAX}},
  url = {https://joss.theoj.org/papers/10.21105/joss.05504},
  volume = {8},
  year = {2023}
}
```