Skip to content

Commit

Permalink
Port DAGCircuit to Rust
Browse files Browse the repository at this point in the history
This commit migrates the entirety of the `DAGCircuit` class to Rust. It
fully replaces the Python version of the class. The primary advantage
of this migration is moving from a Python space rustworkx directed graph
representation to a Rust space petgraph (the upstream library for
rustworkx) directed graph. Moving the graph data structure to rust
enables us to directly interact with the DAG directly from transpiler
passes in Rust in the future. This will enable a significant speed-up in
those transpiler passes. Additionally, this should also improve the
memory footprint as the DAGCircuit no longer stores `DAGNode`
instances, and instead stores a lighter enum NodeType, which simply
contains a `PackedInstruction` or the wire objects directly.

Internally, the new Rust-based `DAGCircuit` uses a `petgraph::StableGraph`
with node weights of type `NodeType` and edge weights of type `Wire`. The
NodeType enum contains variants for `QubitIn`, `QubitOut`, `ClbitIn`,
`ClbitOut`, and `Operation`, which should save us from all of the
`isinstance` checking previously needed when working with `DAGNode` Python
instances. The `Wire` enum contains variants `Qubit`, `Clbit`, and `Var`.

As the full Qiskit data model is not rust-native at this point while
all the class code in the `DAGCircuit` exists in Rust now, there are
still sections that rely on Python or actively run Python code via Rust
to function. These typically involve anything that uses `condition`,
control flow, classical vars, calibrations, bit/register manipulation,
etc. In the future as we either migrate this functionality to Rust or
deprecate and remove it this can be updated in place to avoid the use
of Python.

API access from Python-space remains in terms of `DAGNode` instances to
maintain API compatibility with the Python implementation. However,
internally, we convert to and deal in terms of NodeType. When the user
requests a particular node via lookup or iteration, we inflate an ephemeral
`DAGNode` based on the internal `NodeType` and give them that. This is very
similar to what was done in Qiskit#10827 when porting CircuitData to Rust.

As part of this porting there are a few small differences to keep in
mind with the new Rust implementation of DAGCircuit. The first is that
the topological ordering is slightly different with the new DAGCircuit.
Previously, the Python version of `DAGCircuit` using a lexicographical
topological sort key which was basically `"0,1,0,2"` where the first
`0,1` are qargs on qubit indices `0,1` for nodes and `0,2` are cargs
on clbit indices `0,2`. However, the sort key has now changed to be
`(&[Qubit(0), Qubit(1)], &[Clbit(0), Clbit(2)])` in rust in this case
which for the most part should behave identically, but there are some
edge cases that will appear where the sort order is different. It will
always be a valid topological ordering as the lexicographical key is
used as a tie breaker when generating a topological sort. But if you're
relaying on the exact same sort order there will be differences after
this PR. The second is that a lot of undocumented functionality in the
DAGCircuit which previously worked because of Python's implicit support
for interacting with data structures is no longer functional. For
example, previously the `DAGCircuit.qubits` list could be set directly
(as the circuit visualizers previously did), but this was never
documented as supported (and would corrupt the DAGCircuit). Any
functionality like this we'd have to explicit include in the Rust
implementation and as they were not included in the documented public
API this PR opted to remove the vast majority of this type of
functionality.

The last related thing might require future work to mitigate is that
this PR breaks the linkage between `DAGNode` and the underlying
`DAGCirucit` object. In the Python implementation the `DAGNode` objects
were stored directly in the `DAGCircuit` and when an API method returned
a `DAGNode` from the DAG it was a shared reference to the underlying
object in the `DAGCircuit`. This meant if you mutated the `DAGNode` it
would be reflected in the `DAGCircuit`. This was not always a sound
usage of the API as the `DAGCircuit` was implicitly caching many
attributes of the DAG and you should always be using the `DAGCircuit`
API to mutate any nodes to prevent any corruption of the `DAGCircuit`.
However, now as the underlying data store for nodes in the DAG are
no longer the python space objects returned by `DAGCircuit` methods
mutating a `DAGNode` will not make any change in the underlying
`DAGCircuit`. This can come as quite the surprise at first, especially
if you were relying on this side effect, even if it was unsound.

It's also worth noting that 2 large pieces of functionality from
rustworkx are included in this PR. These are the new files
`rustworkx_core_vnext` and `dot_utils` which are rustworkx's VF2
implementation and its dot file generation. As there was not a rust
interface exposed for this functionality from rustworkx-core there was
no way to use these functions in rustworkx. Until these interfaces
added to rustworkx-core in future releases we'll have to keep these
local copies. The vf2 implementation is in progress in
Qiskit/rustworkx#1235, but `dot_utils` might make sense to keep around
longer term as it is slightly modified from the upstream rustworkx
implementation to directly interface with `DAGCircuit` instead of a
generic graph.

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
Co-authored-by: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com>
Co-authored-by: Elena Peña Tapia <57907331+ElePT@users.noreply.github.com>
Co-authored-by: Alexander Ivrii <alexi@il.ibm.com>
Co-authored-by: Eli Arbel <46826214+eliarbel@users.noreply.github.com>
Co-authored-by: John Lapeyre <jlapeyre@users.noreply.github.com>
Co-authored-by: Jake Lishman <jake.lishman@ibm.com>
  • Loading branch information
8 people committed Aug 12, 2024
1 parent c7e7016 commit c911205
Show file tree
Hide file tree
Showing 53 changed files with 8,901 additions and 2,978 deletions.
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ license = "Apache-2.0"
[workspace.dependencies]
bytemuck = "1.16"
indexmap.version = "2.3.0"
hashbrown.version = "0.14.0"
hashbrown.version = "0.14.5"
num-bigint = "0.4"
num-complex = "0.4"
ndarray = "^0.15.6"
numpy = "0.21.0"
smallvec = "1.13"
thiserror = "1.0"
rustworkx-core = "0.15"
approx = "0.5"
itertools = "0.13.0"
ahash = "0.8.11"

# Most of the crates don't need the feature `extension-module`, since only `qiskit-pyext` builds an
Expand Down
6 changes: 3 additions & 3 deletions crates/accelerate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ rand_distr = "0.4.3"
ahash.workspace = true
num-traits = "0.2"
num-complex.workspace = true
rustworkx-core.workspace = true
num-bigint.workspace = true
rustworkx-core = "0.15"
faer = "0.19.1"
itertools = "0.13.0"
itertools.workspace = true
qiskit-circuit.workspace = true
thiserror.workspace = true

Expand All @@ -38,7 +38,7 @@ workspace = true
features = ["rayon", "approx-0_5"]

[dependencies.approx]
version = "0.5"
workspace = true
features = ["num-complex"]

[dependencies.hashbrown]
Expand Down
26 changes: 2 additions & 24 deletions crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ use qiskit_circuit::circuit_instruction::CircuitInstruction;
use qiskit_circuit::dag_node::DAGOpNode;
use qiskit_circuit::gate_matrix::ONE_QUBIT_IDENTITY;
use qiskit_circuit::imports::QI_OPERATOR;
use qiskit_circuit::operations::{Operation, OperationRef};
use qiskit_circuit::operations::Operation;

use crate::QiskitError;

fn get_matrix_from_inst<'py>(
py: Python<'py>,
inst: &'py CircuitInstruction,
) -> PyResult<Array2<Complex64>> {
if let Some(mat) = inst.op().matrix(&inst.params) {
if let Some(mat) = inst.operation.matrix(&inst.params) {
Ok(mat)
} else if inst.operation.try_standard_gate().is_some() {
Err(QiskitError::new_err(
Expand Down Expand Up @@ -124,29 +124,7 @@ pub fn change_basis(matrix: ArrayView2<Complex64>) -> Array2<Complex64> {
trans_matrix
}

#[pyfunction]
pub fn collect_2q_blocks_filter(node: &Bound<PyAny>) -> Option<bool> {
let Ok(node) = node.downcast::<DAGOpNode>() else {
return None;
};
let node = node.borrow();
match node.instruction.op() {
gate @ (OperationRef::Standard(_) | OperationRef::Gate(_)) => Some(
gate.num_qubits() <= 2
&& node
.instruction
.extra_attrs
.as_ref()
.and_then(|attrs| attrs.condition.as_ref())
.is_none()
&& !node.is_parameterized(),
),
_ => Some(false),
}
}

pub fn convert_2q_block_matrix(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(blocks_to_matrix))?;
m.add_wrapped(wrap_pyfunction!(collect_2q_blocks_filter))?;
Ok(())
}
24 changes: 4 additions & 20 deletions crates/accelerate/src/euler_one_qubit_decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ pub fn compute_error_list(
.iter()
.map(|node| {
(
node.instruction.op().name().to_string(),
node.instruction.operation.name().to_string(),
smallvec![], // Params not needed in this path
)
})
Expand Down Expand Up @@ -988,10 +988,11 @@ pub fn optimize_1q_gates_decomposition(
.iter()
.map(|node| {
if let Some(err_map) = error_map {
error *= compute_error_term(node.instruction.op().name(), err_map, qubit)
error *=
compute_error_term(node.instruction.operation.name(), err_map, qubit)
}
node.instruction
.op()
.operation
.matrix(&node.instruction.params)
.expect("No matrix defined for operation")
})
Expand Down Expand Up @@ -1043,22 +1044,6 @@ fn matmul_1q(operator: &mut [[Complex64; 2]; 2], other: Array2<Complex64>) {
];
}

#[pyfunction]
pub fn collect_1q_runs_filter(node: &Bound<PyAny>) -> bool {
let Ok(node) = node.downcast::<DAGOpNode>() else {
return false;
};
let node = node.borrow();
let op = node.instruction.op();
op.num_qubits() == 1
&& op.num_clbits() == 0
&& op.matrix(&node.instruction.params).is_some()
&& match &node.instruction.extra_attrs {
None => true,
Some(attrs) => attrs.condition.is_none(),
}
}

pub fn euler_one_qubit_decomposer(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(params_zyz))?;
m.add_wrapped(wrap_pyfunction!(params_xyx))?;
Expand All @@ -1072,7 +1057,6 @@ pub fn euler_one_qubit_decomposer(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(compute_error_one_qubit_sequence))?;
m.add_wrapped(wrap_pyfunction!(compute_error_list))?;
m.add_wrapped(wrap_pyfunction!(optimize_1q_gates_decomposition))?;
m.add_wrapped(wrap_pyfunction!(collect_1q_runs_filter))?;
m.add_class::<OneQubitGateSequence>()?;
m.add_class::<OneQubitGateErrorMap>()?;
m.add_class::<EulerBasis>()?;
Expand Down
14 changes: 13 additions & 1 deletion crates/circuit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,29 @@ name = "qiskit_circuit"
doctest = false

[dependencies]
rayon = "1.10"
ahash.workspace = true
rustworkx-core.workspace = true
bytemuck.workspace = true
hashbrown.workspace = true
num-complex.workspace = true
ndarray.workspace = true
numpy.workspace = true
thiserror.workspace = true
approx.workspace = true
itertools.workspace = true

[dependencies.pyo3]
workspace = true
features = ["hashbrown", "indexmap", "num-complex", "num-bigint", "smallvec"]

[dependencies.hashbrown]
workspace = true
features = ["rayon"]

[dependencies.indexmap]
workspace = true
features = ["rayon"]

[dependencies.smallvec]
workspace = true
features = ["union"]
Expand Down
80 changes: 66 additions & 14 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pyo3::prelude::*;
use pyo3::types::PyList;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::mem::swap;

/// Private wrapper for Python-side Bit instances that implements
/// [Hash] and [Eq], allowing them to be used in Rust hash-based
Expand Down Expand Up @@ -81,17 +82,6 @@ pub struct BitData<T> {
cached: Py<PyList>,
}

pub struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);

impl<'py> From<BitNotFoundError<'py>> for PyErr {
fn from(error: BitNotFoundError) -> Self {
PyKeyError::new_err(format!(
"Bit {:?} has not been added to this circuit.",
error.0
))
}
}

impl<T> BitData<T>
where
T: From<BitType> + Copy,
Expand Down Expand Up @@ -139,14 +129,19 @@ where
pub fn map_bits<'py>(
&self,
bits: impl IntoIterator<Item = Bound<'py, PyAny>>,
) -> Result<impl Iterator<Item = T>, BitNotFoundError<'py>> {
) -> PyResult<impl Iterator<Item = T>> {
let v: Result<Vec<_>, _> = bits
.into_iter()
.map(|b| {
self.indices
.get(&BitAsKey::new(&b))
.copied()
.ok_or_else(|| BitNotFoundError(b))
.ok_or_else(|| {
PyKeyError::new_err(format!(
"Bit {:?} has not been added to this circuit.",
b
))
})
})
.collect();
v.map(|x| x.into_iter())
Expand All @@ -168,7 +163,7 @@ where
}

/// Adds a new Python bit.
pub fn add(&mut self, py: Python, bit: &Bound<PyAny>, strict: bool) -> PyResult<()> {
pub fn add(&mut self, py: Python, bit: &Bound<PyAny>, strict: bool) -> PyResult<T> {
if self.bits.len() != self.cached.bind(bit.py()).len() {
return Err(PyRuntimeError::new_err(
format!("This circuit's {} list has become out of sync with the circuit data. Did something modify it?", self.description)
Expand All @@ -193,6 +188,29 @@ where
bit
)));
}
Ok(idx.into())
}

pub fn remove_indices<I>(&mut self, py: Python, indices: I) -> PyResult<()>
where
I: IntoIterator<Item = T>,
{
let mut indices_sorted: Vec<usize> = indices
.into_iter()
.map(|i| <BitType as From<T>>::from(i) as usize)
.collect();
indices_sorted.sort();

for index in indices_sorted.into_iter().rev() {
self.cached.bind(py).del_item(index)?;
let bit = self.bits.remove(index);
self.indices.remove(&BitAsKey::new(bit.bind(py)));
}
// Update indices.
for (i, bit) in self.bits.iter().enumerate() {
self.indices
.insert(BitAsKey::new(bit.bind(py)), (i as BitType).into());
}
Ok(())
}

Expand All @@ -203,3 +221,37 @@ where
self.bits.clear();
}
}

pub struct Iter<'a, T> {
_data: &'a BitData<T>,
index: usize,
}

impl<'a, T> Iterator for Iter<'a, T>
where
T: From<BitType>,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
let mut index = self.index + 1;
swap(&mut self.index, &mut index);
let index: Option<BitType> = index.try_into().ok();
index.map(|i| From::from(i))
}
}

impl<'a, T> IntoIterator for &'a BitData<T>
where
T: From<BitType>,
{
type Item = T;
type IntoIter = Iter<'a, T>;

fn into_iter(self) -> Self::IntoIter {
Iter {
_data: self,
index: 0,
}
}
}
Loading

0 comments on commit c911205

Please sign in to comment.