diff --git a/src/qiskit_qec/circuits/__init__.py b/src/qiskit_qec/circuits/__init__.py index 23bd413b..433410af 100644 --- a/src/qiskit_qec/circuits/__init__.py +++ b/src/qiskit_qec/circuits/__init__.py @@ -34,3 +34,4 @@ from .repetition_code import RepetitionCodeCircuit, ArcCircuit from .surface_code import SurfaceCodeCircuit from .css_code import CSSCodeCircuit +from .stim_code_circuit import StimCodeCircuit diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index 916e9698..714b27fc 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -52,6 +52,14 @@ def string2nodes(self, string, **kwargs): """ pass + @abstractmethod + def measured_logicals(self): + """ + Returns a list of logical operators, each expressed as a list of qubits for which + the parity of the final readouts corresponds to the raw logical readout. + """ + pass + @abstractmethod def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ diff --git a/src/qiskit_qec/circuits/css_code.py b/src/qiskit_qec/circuits/css_code.py index e95cef40..3f94ba75 100644 --- a/src/qiskit_qec/circuits/css_code.py +++ b/src/qiskit_qec/circuits/css_code.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- + # This code is part of Qiskit. # # (C) Copyright IBM 2019. @@ -11,21 +12,19 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -"""Class that manage circuits for CSS codes.""" -# pylint: disable=invalid-name +# pylint: disable=invalid-name, disable=no-name-in-module + +"""Generates circuits for CSS codes.""" from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister from qiskit_aer.noise import depolarizing_error, pauli_error -# pylint: disable=no-name-in-module -from stim import Circuit as StimCircuit -from stim import target_rec as StimTarget_rec - -from qiskit_qec.utils import DecodingGraphNode from qiskit_qec.circuits.code_circuit import CodeCircuit from qiskit_qec.utils.stim_tools import ( noisify_circuit, get_stim_circuits, detector_error_model_to_rx_graph, + string2nodes_with_detectors, + string2logical_meas, ) from qiskit_qec.codes import StabSubSystemCode from qiskit_qec.operators.pauli_list import PauliList @@ -34,31 +33,37 @@ class CSSCodeCircuit(CodeCircuit): - """CodeCircuit class for generic CSS codes.""" + """ + CodeCircuit class for generic CSS codes. + """ def __init__( - self, code, T: int, basis: str = "z", round_schedule: str = "zx", noise_model=None + self, + code, + T: int, + basis: str = "z", + round_schedule: str = "zx", + noise_model=None, ): - """CSSCodeCircuit init method - + """ Args: code: A CSS code class which is either - a) StabSubSystemCode - b) a class with the following methods: - 'x_gauges' (as a list of list of qubit indices), - 'z_gauges', - 'x_stabilizers', - 'z_stabilizers', - 'logical_x', - 'logical_z', - 'n' (number of qubits), + a) StabSubSystemCode + b) a class with the following methods: + 'x_gauges' (as a list of list of qubit indices), + 'z_gauges', + 'x_stabilizers', + 'z_stabilizers', + 'logical_x', + 'logical_z', + 'n' (number of qubits), T: Number of syndrome measurement rounds basis: basis for encoding ('x' or 'z') round_schedule: Order in which to measureme gauge operators ('zx' or 'xz') noise_model: Pauli noise model used in the construction of noisy circuits. - If a tuple, a pnenomological noise model is used with the entries being - probabity of depolarizing noise on code qubits between rounds and - probability of measurement errors, respectively. + If a tuple, a pnenomological noise model is used with the entries being + probabity of depolarizing noise on code qubits between rounds and + probability of measurement errors, respectively. Examples: The QuantumCircuit of a memory experiment for the distance-3 HeavyHEX code @@ -66,8 +71,8 @@ def __init__( >>> from qiskit_qec.circuits.css_code import CSSCodeCircuit >>> code = CSSCodeCircuit(HHC(3),T=3,basis='x',noise_model=(0.01,0.01),round_schedule='xz') >>> code.circuit['0'] - """ + super().__init__() self.code = code @@ -124,6 +129,7 @@ def __init__( if set(stabilizer).intersection(set(gauge)) == set(gauge): gauges.append(g) self._gauges4stabilizers[j].append(gauges) + self.detectors, self.logicals = self.stim_detectors() def _get_code_properties(self): if isinstance(self.code, StabSubSystemCode): @@ -138,10 +144,10 @@ def _get_code_properties(self): stabilizers = [[], []] logicals = [[], []] - for ( - raw_ops, - ops, - ) in zip([raw_gauges, raw_stabilizers, raw_logicals], [gauges, stabilizers, logicals]): + for (raw_ops, ops,) in zip( + [raw_gauges, raw_stabilizers, raw_logicals], + [gauges, stabilizers, logicals], + ): for op in raw_ops: op = str(op) for j, pauli in enumerate(["X", "Z"]): @@ -177,6 +183,16 @@ def _get_code_properties(self): self.z_stabilizers = self.code.z_stabilizers self.logical_x = self.code.logical_x self.logical_z = self.code.logical_z + # for the unionfind decoder + self.css_x_logical = self.logical_x + self.css_z_logical = self.logical_z + + def measured_logicals(self): + if self.basis == "x": + measured_logicals = self.logical_x + else: + measured_logicals = self.logical_z + return measured_logicals def _prepare_initial_state(self, qc, qregs, state): if state[0] == "1": @@ -247,98 +263,31 @@ def string2nodes(self, string, **kwargs): logical (str): Logical value whose results are used ('0' as default). all_logicals (bool): Whether to include logical nodes irrespective of value. (False as default). - """ - all_logicals = kwargs.get("all_logicals") - logical = kwargs.get("logical") - if logical is None: - logical = "0" + Returns: + nodes: a list of 'DecodingGraphNode()'s corresponding to the triggered detectors + """ - output = string.split(" ")[::-1] - gauge_outs = [[], []] - for t in range(self.T): - gauge_outs[0].append( - [int(b) for b in output[2 * t + self.round_schedule.find("x")]][::-1] - ) - gauge_outs[1].append( - [int(b) for b in output[2 * t + self.round_schedule.find("z")]][::-1] - ) - final_outs = [int(b) for b in output[-1]] + nodes = string2nodes_with_detectors( + string=string, + detectors=self.detectors, + logicals=self.logicals, + clbits=self.circuit["0"].clbits, + det_ref_values=0, + **kwargs, + ) + return nodes - stabilizer_outs = [] - for j in range(2): - stabilizer_outs.append([]) - for t in range(self.T): - round_outs = [] - for gs in self._gauges4stabilizers[j]: - out = 0 - for g in gs: - out += gauge_outs[j][t][g] - out = out % 2 - round_outs.append(out) - stabilizer_outs[j].append(round_outs) - - bases = ["x", "z"] - j = bases.index(self.basis) - final_gauges = [] - for gauge in self._gauges[j]: - out = 0 - for q in gauge: - out += final_outs[-q - 1] - out = out % 2 - final_gauges.append(out) - final_stabilizers = [] - for gs in self._gauges4stabilizers[j]: - out = 0 - for g in gs: - out += final_gauges[g] - out = out % 2 - final_stabilizers.append(out) - stabilizer_outs[j].append(final_stabilizers) - - stabilizer_changes = [] - for j in range(2): - stabilizer_changes.append([]) - for t in range(self.T + (bases[j] == self.basis)): - stabilizer_changes[j].append([]) - for e in range(len(stabilizer_outs[j][t])): - if t == 0 and j == bases.index(self.basis): - stabilizer_changes[j][t].append(stabilizer_outs[j][t][e]) - else: - stabilizer_changes[j][t].append( - (stabilizer_outs[j][t][e] + stabilizer_outs[j][t - 1][e]) % 2 - ) - - nodes = [] - for j in range(2): - for t, round_changes in enumerate(stabilizer_changes[j]): - for e, change in enumerate(round_changes): - if change == 1: - node = DecodingGraphNode(time=t, qubits=self._stabilizers[j][e], index=e) - node.properties["basis"] = bases[j] - nodes.append(node) + def string2raw_logicals(self, string): + """ + Converts output string into a list of logical measurement outcomes + Logicals are the logical measurements produced by self.stim_detectors() + """ - if self.basis == "x": - logicals = self.logical_x - else: - logicals = self.logical_z - - for index, logical_op in enumerate(logicals): - logical_out = 0 - for q in logical_op: - logical_out += final_outs[-q - 1] - logical_out = logical_out % 2 - - if all_logicals or str(logical_out) != logical: - node = DecodingGraphNode( - is_boundary=True, - qubits=logical, - index=index, - ) - node.properties["basis"] = self.basis - nodes.append(node) + _, self.logicals = self.stim_detectors() - return nodes + log_outs = string2logical_meas(string, self.logicals, self.circuit["0"].clbits) + return log_outs def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): raise NotImplementedError @@ -346,145 +295,116 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): def is_cluster_neutral(self, atypical_nodes): raise NotImplementedError - def stim_circuit_with_detectors(self): - """Converts the qiskit circuits into stim ciruits and add detectors. - This is required for the stim-based construction of the DecodingGraph. + def stim_detectors(self): + """ + Constructs detectors and logicals required for stim. + + Returns: + detectors (list[dict]): dictionaries containing + a) 'clbits', the classical bits (register, index) included in the measurement comparisons + b) 'qubits', the qubits (list of indices) participating in the stabilizer measurements + c) 'time', measurement round (int) of the earlier measurements in the detector + d) 'basis', the pauli basis ('x' or 'z') of the stabilizers + logicals (list[dict]): dictionaries containing + a) 'clbits', the classical bits (register, index) included in the logical measurement + b) 'basis', the pauli basis ('x' or 'z') of the logical """ - stim_circuits, _ = get_stim_circuits(self.noisy_circuit) - measurements_per_cycle = len(self.x_gauges) + len(self.z_gauges) - if self.round_schedule[0] == "x": - measurement_round_offset = [0, len(self.x_gauges)] - else: - measurement_round_offset = [len(self.z_gauges), 0] + detectors = [] + logicals = [] ## 0th round of measurements if self.basis == "x": + reg = "round_0_x_bits" for stabind, stabilizer in enumerate(self.x_stabilizers): - record_targets = [] + det = {"clbits": []} for gauge_ind in self._gauges4stabilizers[0][stabind]: - record_targets.append( - StimTarget_rec( - measurement_round_offset[0] - + gauge_ind - - (self.T * measurements_per_cycle + self.code.n) - ) - ) - qubits_and_time = stabilizer.copy() - qubits_and_time.extend([0]) - stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) + det["clbits"].append((reg, gauge_ind)) + det["qubits"] = stabilizer.copy() + det["time"] = 0 + det["basis"] = "x" + detectors.append(det) + else: + reg = "round_0_z_bits" for stabind, stabilizer in enumerate(self.z_stabilizers): - record_targets = [] + det = {"clbits": []} for gauge_ind in self._gauges4stabilizers[1][stabind]: - record_targets.append( - StimTarget_rec( - measurement_round_offset[1] - + gauge_ind - - (self.T * measurements_per_cycle + self.code.n) - ) - ) - qubits_and_time = stabilizer.copy() - qubits_and_time.extend([0]) - stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) + det["clbits"].append((reg, gauge_ind)) + det["qubits"] = stabilizer.copy() + det["time"] = 0 + det["basis"] = "z" + detectors.append(det) # adding first x and then z stabilizer comparisons - for j in range(2): - circuit = StimCircuit() + for j, basis in enumerate(["x", "z"]): for t in range( 1, self.T ): # compare stabilizer measurements with previous in each round + reg_prev = "round_" + str(t - 1) + "_" + basis + "_bits" + reg_t = "round_" + str(t) + "_" + basis + "_bits" for gind, gs in enumerate(self._gauges4stabilizers[j]): - record_targets = [] + det = {"clbits": []} for gauge_ind in gs: - record_targets.append( - StimTarget_rec( - t * measurements_per_cycle - + measurement_round_offset[j] - + gauge_ind - - (self.T * measurements_per_cycle + self.code.n) - ) - ) - record_targets.append( - StimTarget_rec( - (t - 1) * measurements_per_cycle - + measurement_round_offset[j] - + gauge_ind - - (self.T * measurements_per_cycle + self.code.n) - ) - ) - qubits_and_time = self._stabilizers[j][gind].copy() - qubits_and_time.extend([t]) - circuit.append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["0"] += circuit - stim_circuits["1"] += circuit + det["clbits"].append((reg_t, gauge_ind)) + det["clbits"].append((reg_prev, gauge_ind)) + det["qubits"] = self._stabilizers[j][gind].copy() + det["time"] = t + det["basis"] = basis + detectors.append(det) ## final measurements if self.basis == "x": + reg_prev = "round_" + str(self.T - 1) + "_x_bits" + reg_T = "final_readout" for stabind, stabilizer in enumerate(self.x_stabilizers): - record_targets = [] + det = {"clbits": []} for q in stabilizer: - record_targets.append(StimTarget_rec(q - self.code.n)) + det["clbits"].append((reg_T, q)) for gauge_ind in self._gauges4stabilizers[0][stabind]: - record_targets.append( - StimTarget_rec( - measurement_round_offset[0] - + gauge_ind - - self.code.n - - measurements_per_cycle - ) - ) - qubits_and_time = stabilizer.copy() - qubits_and_time.extend([self.T]) - stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["0"].append( - "OBSERVABLE_INCLUDE", - [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_x[0])], - 0, - ) - stim_circuits["1"].append( - "OBSERVABLE_INCLUDE", - [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_x[0])], - 0, + det["clbits"].append((reg_prev, gauge_ind)) + det["qubits"] = stabilizer.copy() + det["time"] = self.T + det["basis"] = "x" + detectors.append(det) + logicals.append( + { + "clbits": [(reg_T, q) for q in sorted(self.logical_x[0])], + "basis": "z", + } ) else: + reg_prev = "round_" + str(self.T - 1) + "_z_bits" + reg_T = "final_readout" for stabind, stabilizer in enumerate(self.z_stabilizers): - record_targets = [] + det = {"clbits": []} for q in stabilizer: - record_targets.append(StimTarget_rec(q - self.code.n)) + det["clbits"].append((reg_T, q)) for gauge_ind in self._gauges4stabilizers[1][stabind]: - record_targets.append( - StimTarget_rec( - measurement_round_offset[1] - + gauge_ind - - self.code.n - - measurements_per_cycle - ) - ) - qubits_and_time = stabilizer.copy() - qubits_and_time.extend([self.T]) - stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) - stim_circuits["0"].append( - "OBSERVABLE_INCLUDE", - [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_z[0])], - 0, - ) - stim_circuits["1"].append( - "OBSERVABLE_INCLUDE", - [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_z[0])], - 0, + det["clbits"].append((reg_prev, gauge_ind)) + det["qubits"] = stabilizer.copy() + det["time"] = self.T + det["basis"] = "x" + detectors.append(det) + logicals.append( + { + "clbits": [(reg_T, q) for q in sorted(self.logical_z[0])], + "basis": "z", + } ) - return stim_circuits + return detectors, logicals def _make_syndrome_graph(self): - stim_circuit = self.stim_circuit_with_detectors()["0"] + """ + Used by the DecodingGraph class to build the decoding graph and the obtain hyperedges + """ + detectors, logicals = self.stim_detectors() + stim_circuit = get_stim_circuits( + self.noisy_circuit["0"], detectors=detectors, logicals=logicals + )[0][0] e = stim_circuit.detector_error_model( decompose_errors=True, approximate_disjoint_errors=True ) - graph, hyperedges = detector_error_model_to_rx_graph(e) + graph, hyperedges = detector_error_model_to_rx_graph(e, detectors=detectors) return graph, hyperedges diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 0b621d6a..4ae904d2 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -243,6 +243,9 @@ def readout(self): self.circuit[log].add_register(self.code_bit) self.circuit[log].measure(self.code_qubit, self.code_bit) + def measured_logicals(self): + return [[0]] + def _process_string(self, string): # logical readout taken from measured_log = string[0] + " " + string[self.d - 1] @@ -965,6 +968,9 @@ def _readout(self): qc.add_register(self.code_bit) qc.measure(self.code_qubit, self.code_bit) + def measured_logicals(self): + return [[self.z_logicals[0]]] + def _process_string(self, string): # logical readout taken from assigned qubits measured_log = "" diff --git a/src/qiskit_qec/circuits/stim_code_circuit.py b/src/qiskit_qec/circuits/stim_code_circuit.py new file mode 100644 index 00000000..888fc9d3 --- /dev/null +++ b/src/qiskit_qec/circuits/stim_code_circuit.py @@ -0,0 +1,593 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name, disable=no-name-in-module, disable=no-member + +"""Generates CodeCircuits from stim circuits""" +import warnings + +from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister + +from qiskit.circuit.library.standard_gates import ( + IGate, + XGate, + YGate, + ZGate, + HGate, + SGate, + SdgGate, + CXGate, + CYGate, + CZGate, + SwapGate, +) + +from stim import CircuitInstruction, CircuitRepeatBlock, target_inv +from stim import Circuit as StimCircuit + +from qiskit_qec.circuits.code_circuit import CodeCircuit +from qiskit_qec.utils.stim_tools import ( + detector_error_model_to_rx_graph, + string2nodes_with_detectors, + string2logical_meas, +) + + +class StimCodeCircuit(CodeCircuit): + """ + Prepares a CodeCircuit class based on the supplied stim circuit. + """ + + def __init__( + self, + stim_circuit: StimCircuit, + barriers: bool = True, + ): + """ + Args: + stim_circuit: stim circuit to be coverted + barriers (optional): whether to include barriers (coeersponding to stim TICK instructions) + in the qiskit circuits. Default is True + Examples: + Prepare and measure a Bell-state, checking the parity with a measurement + comparison (DETECTOR) + >>> from qiskit_qec.circuits.stim_code_circuit import StimCodeCircuit + >>> stim_ex1 = stim.Circuit(''' + >>> H 0 + >>> TICK + >>> CX 0 1 + >>> X_ERROR(0.2) 0 1 + >>> TICK + >>> M 0 1 + >>> DETECTOR rec[-1] rec[-2] + >>> ''') + >>> stim_code = StimCodeCircuit(stim_circuit = stim_ex1) + >>> stim_code.qc + """ + super().__init__() + self.stim_circuit = stim_circuit + + self.measurement_data = [] + self.detectors = [] + self.logicals = [] + + self.qc = QuantumCircuit() + self.qc.add_register(QuantumRegister(self.stim_circuit.num_qubits)) + + single_qubit_gate_dict = { + "I": IGate(), + "X": XGate(), + "Y": YGate(), + "Z": ZGate(), + "H": HGate(), + "S": SGate(), + "S_DAG": SdgGate(), + } + + two_qubit_gate_dict = { + "CX": CXGate(), + "CY": CYGate(), + "CZ": CZGate(), + "SWAP": SwapGate(), + } + + def _helper(stim_circuit: StimCircuit, reps: int): + nonlocal rep_block_count + nonlocal block_count + for rep_ind in range(reps): + meas_count = 0 + for instruction in stim_circuit: + if isinstance(instruction, CircuitRepeatBlock): + _helper(instruction.body_copy(), instruction.repeat_count) + rep_block_count += 1 + elif isinstance(instruction, CircuitInstruction): + inst_name = instruction.name + if inst_name in single_qubit_gate_dict: + qubits = [target.value for target in instruction.targets_copy()] + for q in qubits: + self.qc.append(single_qubit_gate_dict[inst_name], qargs=[q]) + elif inst_name in two_qubit_gate_dict: + stim_targets = instruction.targets_copy() + for t1, t2 in zip(stim_targets[::2], stim_targets[1::2]): + if t1.is_qubit_target: + q1, q2 = t1.value, t2.value + self.qc.append(two_qubit_gate_dict[inst_name], qargs=[q1, q2]) + elif t1.is_measurement_record_target: + c1, q2 = t1.value, t2.value + inst_name = inst_name[1:] # remove the C from CX,CY,CZ + self.qc.append( + single_qubit_gate_dict[inst_name], qargs=[q2] + ).c_if(c1, 1) + + elif inst_name == "M": + qubits = [target.value for target in instruction.targets_copy()] + invert_result = [ + target.is_inverted_result_target + for target in instruction.targets_copy() + ] + if reps > 1: + cl_reg_name = ( + "rep_block_" + + str(rep_block_count) + + "_rep_" + + str(rep_ind) + + "_meas_block_" + + str(meas_count) + ) + else: + cl_reg_name = ( + "block_" + + str(rep_block_count) + + "_meas_block_" + + str(meas_count) + ) + creg = ClassicalRegister(len(qubits), name=cl_reg_name) + self.qc.add_register(creg) + flip_qubits = [q for q, inv in zip(qubits, invert_result) if inv] + if flip_qubits != []: + self.qc.x(flip_qubits) + self.qc.measure(qubits, creg) + for i, q in enumerate(qubits): + self.measurement_data.append((cl_reg_name, i)) + if flip_qubits != []: + self.qc.x(flip_qubits) + meas_count += 1 + elif inst_name == "R": + qubits = [target.value for target in instruction.targets_copy()] + self.qc.reset(qubits) + + elif inst_name == "TICK" and barriers: + self.qc.barrier() + + elif inst_name == "DETECTOR": + meas_targets = [t.value for t in instruction.targets_copy()] + self.detectors.append( + { + "clbits": [self.measurement_data[ind] for ind in meas_targets], + "time": [], + "qubits": [], + "stim_coords": instruction.gate_args_copy(), + } + ) + elif inst_name == "OBSERVABLE_INCLUDE": + meas_targets = [t.value for t in instruction.targets_copy()] + self.logicals.append( + {"clbits": [self.measurement_data[ind] for ind in meas_targets]} + ) + + rep_block_count = 0 + block_count = 0 + self.decomp_stim_circuit = self.decompose_stim_circuit(self.stim_circuit) + _helper(self.decomp_stim_circuit, 1) + + self.circuit = self.qc + + # if a set of measurement comparisons is deterministically 1 in the absence of errors, + # the set of syndromes is compared to that + noiseless_measurements = self.decomp_stim_circuit.compile_sampler().sample(1)[0] + clbit_dict = { + (clbit._register.name, clbit._index): clind + for clind, clbit in enumerate(self.qc.clbits) + } + detector_meas_indices = [ + [clbit_dict[clbit] for clbit in det["clbits"]] for det in self.detectors + ] + self.det_ref_values = [ + sum(noiseless_measurements[meas_list]) % 2 for meas_list in detector_meas_indices + ] + + # further code parameters + try: + self.d = len(self.stim_circuit.shortest_graphlike_error()) # code distance + except ValueError: + self.d = 0 + self.n = stim_circuit.num_qubits + # the number of rounds is not necessarily well-defined (Floquet codes etc.) + + def decompose_stim_circuit(self, stim_circuit): + """ + Decompose gates in the stim circuit into Clifford gates that have an equivalent qiskit gate. + Errors are not included. + """ + decompose_1_dict = { + # single-qubit gates + "C_XYZ": [("S_DAG", [0]), ("H", [0])], + "C_ZYX": [("H", [0]), ("S", [0])], + "H_XY": [("X", [0]), ("S", [0])], + "H_XZ": [("H", [0])], + "H_YZ": [("H", [0]), ("S", [0]), ("H", [0]), ("Z", [0])], + "SQRT_X": [("H", [0]), ("S", [0]), ("H", [0])], + "SQRT_X_DAG": [("S", [0]), ("H", [0]), ("S", [0])], + "SQRT_Y": [("Z", [0]), ("H", [0])], + "SQRT_Y_DAG": [("H", [0]), ("Z", [0])], + "SQRT_Z": [("S", [0])], + "SQRT_Z_DAG": [("S_DAG", [0])], + } + + decompose_2_dict = { + # two-qubit gates + "CNOT": [("CX", [0, 1])], + "CXSWAP": [("CX", [0, 1]), ("SWAP", [0, 1])], + "SWAPCX": [("SWAP", [0, 1]), ("CX", [0, 1])], + "ISWAP": [ + ("H", [0]), + ("SWAP", [0, 1]), + ("CX", [0, 1]), + ("S", [0]), + ("H", [1]), + ("S", [1]), + ], + "ISWAP_DAG": [ + ("S_DAG", [0]), + ("S_DAG", [1]), + ("H", [1]), + ("CX", [0, 1]), + ("SWAP", [0, 1]), + ("H", [0]), + ], + "SQRT_XX": [ + ("H", [0]), + ("CX", [0, 1]), + ("S", [0]), + ("H", [0]), + ("H", [1]), + ("S", [1]), + ("H", [1]), + ], + "SQRT_XX_DAG": [ + ("H", [0]), + ("CX", [0, 1]), + ("S_DAG", [0]), + ("H", [0]), + ("H", [1]), + ("S_DAG", [1]), + ("H", [1]), + ], + "SQRT_YY": [ + ("S_DAG", [0]), + ("H", [0]), + ("S_DAG", [1]), + ("CX", [0, 1]), + ("S", [0]), + ("H", [0]), + ("S", [0]), + ("H", [1]), + ("S", [1]), + ("H", [1]), + ("S", [1]), + ], + "SQRT_YY_DAG": [ + ("S_DAG", [0]), + ("H", [0]), + ("S", [1]), + ("CX", [0, 1]), + ("S", [0]), + ("H", [0]), + ("S", [0]), + ("H", [1]), + ("S", [1]), + ("H", [1]), + ("S_DAG", [1]), + ], + "SQRT_ZZ": [("H", [1]), ("CX", [0, 1]), ("S", [0]), ("H", [1]), ("S", [1])], + "SQRT_ZZ_DAG": [ + ("H", [1]), + ("CX", [0, 1]), + ("S_DAG", [0]), + ("H", [1]), + ("S_DAG", [1]), + ], + "XCX": [("H", [0]), ("CX", [0, 1]), ("H", [0])], + "XCY": [("H", [0]), ("CY", [0, 1]), ("H", [0])], + "XCZ": [("H", [0]), ("CZ", [0, 1]), ("S", [0])], + "YCX": [("S_DAG", [0]), ("H", [0]), ("CX", [0, 1]), ("H", [0]), ("S", [0])], + "YCY": [("S_DAG", [0]), ("H", [0]), ("CY", [0, 1]), ("H", [0]), ("S", [0])], + "YCZ": [("S_DAG", [0]), ("H", [0]), ("CZ", [0, 1]), ("H", [0]), ("S", [0])], + "ZCX": [("CX", [0, 1])], + "ZCY": [("CY", [0, 1])], + "ZCZ": [("CZ", [0, 1])], + } + decompose_m_dict = { + # measurements + "MX": [("H", [0]), ("M", [0]), ("H", [0])], + "MY": [ + ("S_DAG", [0]), + ("H", [0]), + ("M", [0]), + ("H", [0]), + ("S", [0]), + ], + "MZ": [("M", [0])], + "RX": [("H", [0]), ("R", [0]), ("H", [0])], + "RY": [ + ("S_DAG", [0]), + ("H", [0]), + ("R", [0]), + ("H", [0]), + ("S", [0]), + ], + "RZ": [("R", [0])], + "MRX": [("H", [0]), ("M", [0]), ("R", [0]), ("H", [0])], + "MRY": [ + ("S_DAG", [0]), + ("H", [0]), + ("M", [0]), + ("R", [0]), + ("H", [0]), + ("S", [0]), + ], + "MRZ": [("M", [0]), ("R", [0])], + "MR": [("M", [0]), ("R", [0])], + "MXX": [("CX", [0, 1]), ("H", [0]), ("M", [0]), ("H", [0]), ("CX", [0, 1])], + "MYY": [ + ("S_DAG", [0]), + ("S_DAG", [1]), + ("CX", [0, 1]), + ("H", [0]), + ("M", [0]), + ("H", [0]), + ("CX", [0, 1]), + ("S", [0]), + ("S", [1]), + ], + "MZZ": [("CX", [0, 1]), ("M", [1]), ("CX", [0, 1])], + } + + stim_error_list = [ + "CORRELATED_ERROR", + "DEPOLARIZE1", + "DEPOLARIZE2", + "E", + "ELSE_CORRELATED_ERROR", + "HERALDED_ERASE", + "HERALDED_PAULI_CHANNEL_1", + "PAULI_CHANNEL_1", + "PAULI_CHANNEL_2", + "X_ERROR", + "Y_ERROR", + "Z_ERROR", + ] + + decomposed_stim_circuit = StimCircuit() + for instruction in stim_circuit: + if isinstance(instruction, CircuitInstruction): + if instruction.name in decompose_1_dict: + targets = instruction.targets_copy() + decomp_gate = StimCircuit() + basis_gates = decompose_1_dict[instruction.name] + for target in targets: + for gate, _ in basis_gates: + decomp_gate.append(gate, [target], []) + decomposed_stim_circuit += decomp_gate + elif instruction.name in decompose_2_dict: + targets = instruction.targets_copy() + decomp_gate = StimCircuit() + basis_gates = decompose_2_dict[instruction.name] + for target_pair in zip(targets[::2], targets[1::2]): + for gate, qubit_ind in basis_gates: + qubits = [target_pair[qubit_ind[0]]] + if len(qubit_ind) == 2: + qubits.append(target_pair[qubit_ind[1]]) + decomp_gate.append(gate, qubits, []) + decomposed_stim_circuit += decomp_gate + elif instruction.name in decompose_m_dict: + arg = instruction.gate_args_copy() + target_list = instruction.targets_copy() + decomp_gate = StimCircuit() + basis_gates = decompose_m_dict[instruction.name] + if instruction.name in ("MXX", "MYY", "MZZ"): + if len(set(target_list)) < len( + target_list + ): # avoiding overlaps that result from stim broadcating + target_list = [ + target_list[i : i + 2] for i in range(0, len(target_list), 2) + ] + else: + target_list = [target_list] + for targets in target_list: + for gate, qubit_ind in basis_gates: + if len(qubit_ind) == 2: + qubit_list = [ + target_pair[q_ind].value + for target_pair in zip(targets[::2], targets[1::2]) + for q_ind in qubit_ind + ] + else: + qubit_list = [ + target_pair[qubit_ind[0]].value + for target_pair in zip(targets[::2], targets[1::2]) + ] + + if gate == "M": + inv_list = [ + ( + target_pair[0].is_inverted_result_target + + target_pair[1].is_inverted_result_target + ) + % 2 + for target_pair in zip(targets[::2], targets[1::2]) + ] + for i, inv in enumerate(inv_list): + if inv: + qubit_list[i] = target_inv(qubit_list[i]) + decomp_gate.append(gate, qubit_list, arg) + else: + decomp_gate.append(gate, qubit_list, []) + + else: + for gate, _ in basis_gates: + qubit_list = [target.value for target in target_list] + if gate == "M": + inv_list = [ + target.is_inverted_result_target for target in target_list + ] + for i, inv in enumerate(inv_list): + if inv: + qubit_list[i] = target_inv(qubit_list[i]) + decomp_gate.append(gate, qubit_list, arg) + else: + decomp_gate.append(gate, qubit_list, []) + + decomposed_stim_circuit += decomp_gate + + elif instruction.name == "MPP": + decomposed_stim_circuit += self.MPP_circuit(instruction) + elif instruction.name == "MPAD": + # MPAD does affect the stim measurement sample output, but not the detector sample. + # This is due to the fact that detectors are sensitive to changes + # wrt the perfect output, + # R 0 + # MPAD 1 + # M 0 + # DETECTOR rec[-1] rec[-2] + # The above example results in a deterministinc detector samples False, not True! + # + # MPAD can have an effect on measurements, when a CX is controlled on an MPAD bit... + # but what is the point of that anyway? + warnings.warn( + "The circuit contains MPAD instructions that are ignored in the conversion." + "This can affect the measurement outcomes, but not the detectors." + ) + pass + elif instruction.name in stim_error_list: + # do not include errors + pass + else: + gate = instruction.name + arg = instruction.gate_args_copy() + targets = instruction.targets_copy() + decomposed_stim_circuit.append(gate, targets, arg) + elif isinstance(instruction, CircuitRepeatBlock): + decomposed_stim_circuit.append( + CircuitRepeatBlock( + instruction.repeat_count, + self.decompose_stim_circuit(instruction.body_copy()), + ) + ) + + return decomposed_stim_circuit + + def MPP_circuit(self, stim_instruction): + """Handle MPP measurements.""" + MPP_stim_circuit = StimCircuit() + arg = stim_instruction.gate_args_copy() + + # break it down into individual Pauli products + target_lists = [] + target_list = [] + prev_is_combiner = True + for target in stim_instruction.targets_copy(): + if prev_is_combiner: + target_list.append(target) + prev_is_combiner = False + elif target.is_combiner: + prev_is_combiner = True + else: + target_lists.append(target_list) + target_list = [target] + prev_is_combiner = False + target_lists.append(target_list) + + for target_list in target_lists: + invert = 0 + first_target_qubit = target_list[0].value + for target in target_list: + if target.is_x_target: + MPP_stim_circuit.append("H", target.value) + elif target.is_y_target: + MPP_stim_circuit.append("S_DAG", target.value) + MPP_stim_circuit.append("H", target.value) + invert = (invert + target.is_inverted_result_target) % 2 + if target.value != first_target_qubit: + MPP_stim_circuit.append("CX", [target.value, first_target_qubit]) + if invert: + MPP_stim_circuit.append("M", target_inv(first_target_qubit), arg) + else: + MPP_stim_circuit.append("M", first_target_qubit, arg) + + for target in target_list[::-1]: + if target.value != first_target_qubit: + MPP_stim_circuit.append("CX", [target.value, first_target_qubit]) + if target.is_x_target: + MPP_stim_circuit.append("H", target.value) + elif target.is_y_target: + MPP_stim_circuit.append("H", target.value) + MPP_stim_circuit.append("S", target.value) + + return MPP_stim_circuit + + def string2nodes(self, string, **kwargs): + """ + Convert output string from circuits into a set of nodes for `DecodingGraph`. + Args: + string (string): Results string to convert. + kwargs (dict): Any additional keyword arguments. + logical (str): Logical value whose results are used ('0' as default). + all_logicals (bool): Whether to include logical nodes + irrespective of value. (False as default). + """ + + nodes = string2nodes_with_detectors( + string=string, + detectors=self.detectors, + logicals=self.logicals, + clbits=self.qc.clbits, + det_ref_values=self.det_ref_values, + **kwargs, + ) + return nodes + + def string2raw_logicals(self, string): + """ + Converts output string into a list of logical measurement outcomes + Logicals are the logical measurements produced by self.stim_detectors() + """ + _, self.logicals = self.stim_detectors() + + log_outs = string2logical_meas(string, self.logicals, self.circuit.clbits) + + return log_outs + + def _make_syndrome_graph(self): + e = self.stim_circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) + graph, hyperedges = detector_error_model_to_rx_graph(e, detectors=self.detectors) + return graph, hyperedges + + def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + raise NotImplementedError + + def is_cluster_neutral(self, atypical_nodes): + raise NotImplementedError diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 50162caa..4775c800 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -316,6 +316,13 @@ def _string2changes(self, string): return syndrome_changes + def measured_logicals(self): + if self.basis == "x": + measured_logicals = self.css_x_logical + else: + measured_logicals = self.css_z_logical + return measured_logicals + def string2raw_logicals(self, string): """ Extracts raw logicals from output string. diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 4bd54dde..3f1c2bbe 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -40,13 +40,8 @@ def __init__( ): self.code = code_circuit - if hasattr(self.code, "_xbasis"): - if self.code._xbasis: - self.measured_logicals = self.code.css_x_logical - else: - self.measured_logicals = self.code.css_z_logical - else: - self.measured_logicals = self.code.css_z_logical + self.measured_logicals = self.code.measured_logicals() + if hasattr(self.code, "code_index"): self.code_index = self.code.code_index else: diff --git a/src/qiskit_qec/utils/stim_tools.py b/src/qiskit_qec/utils/stim_tools.py index 5c8285ea..9bbb732e 100644 --- a/src/qiskit_qec/utils/stim_tools.py +++ b/src/qiskit_qec/utils/stim_tools.py @@ -12,15 +12,17 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -# pylint: disable=invalid-name, disable=no-name-in-module +# pylint: disable=invalid-name, disable=no-name-in-module, disable=unused-argument """Tools to use functionality from Stim.""" -from typing import Union, List, Dict -from math import log +from typing import Union, List, Dict, Callable +from math import log as loga from stim import Circuit as StimCircuit from stim import DetectorErrorModel as StimDetectorErrorModel from stim import DemInstruction as StimDemInstruction +from stim import DemRepeatBlock as StimDemRepeatBlock from stim import DemTarget as StimDemTarget +from stim import target_rec as StimTarget_rec import numpy as np import rustworkx as rx @@ -28,26 +30,55 @@ from qiskit import QuantumCircuit from qiskit_aer.noise.errors.quantum_error import QuantumChannelInstruction from qiskit_aer.noise import pauli_error -from qiskit_qec.utils.decoding_graph_attributes import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils.decoding_graph_attributes import ( + DecodingGraphNode, + DecodingGraphEdge, +) from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel -def get_stim_circuits(circuit_dict: Dict[int, QuantumCircuit]): +def get_stim_circuits( + circuit: Union[QuantumCircuit, List], + detectors: List[Dict] = None, + logicals: List[Dict] = None, +): """Converts compatible qiskit circuits to stim circuits. Dictionaries are not complete. For the stim definitions see: https://github.com/quantumlib/Stim/blob/main/doc/gates.md Args: - circuit_dict: Compatible gates are paulis, controlled paulis, h, s, + circuit: Compatible gates are Paulis, controlled Paulis, h, s, and sdg, swap, reset, measure and barrier. Compatible noise operators correspond to a single or two qubit pauli channel. + detectors: A list of measurement comparisons. A measurement comparison + (detector) is either a list of measurements given by a the name and index + of the classical bit or a list of dictionaries, with a mandatory clbits + key containing the classical bits. A dictionary can contain keys like + 'qubits', 'time', 'basis' etc. + logicals: A list of logical measurements. A logical measurement is a + list of classical bits whose total parity is the logical eigenvalue. + Again it can be a list of dictionaries. Returns: stim_circuits, stim_measurement_data """ - stim_circuits = {} - stim_measurement_data = {} - for circ_label, circuit in circuit_dict.items(): + + if detectors is None: + detectors = [{}] + if logicals is None: + logicals = [{}] + + if len(detectors) > 0 and isinstance(detectors[0], List): + detectors = [{"clbits": det, "qubits": [], "time": 0} for det in detectors] + + if len(logicals) > 0 and isinstance(logicals[0], List): + logicals = [{"clbits": log} for log in logicals] + + stim_circuits = [] + stim_measurement_data = [] + if isinstance(circuit, QuantumCircuit): + circuit = [circuit] + for circ in circuit: stim_circuit = StimCircuit() qiskit_to_stim_dict = { @@ -58,7 +89,7 @@ def get_stim_circuits(circuit_dict: Dict[int, QuantumCircuit]): "h": "H", "s": "S", "sdg": "S_DAG", - "cx": "CNOT", + "cx": "CX", "cy": "CY", "cz": "CZ", "swap": "SWAP", @@ -66,7 +97,16 @@ def get_stim_circuits(circuit_dict: Dict[int, QuantumCircuit]): "measure": "M", "barrier": "TICK", } - pauli_error_1_stim_order = {"id": 0, "I": 0, "X": 1, "x": 1, "Y": 2, "y": 2, "Z": 3, "z": 3} + pauli_error_1_stim_order = { + "id": 0, + "I": 0, + "X": 1, + "x": 1, + "Y": 2, + "y": 2, + "Z": 3, + "z": 3, + } pauli_error_2_stim_order = { "II": 0, "IX": 1, @@ -87,17 +127,22 @@ def get_stim_circuits(circuit_dict: Dict[int, QuantumCircuit]): } measurement_data = [] - register_offset = {} - previous_offset = 0 - for inst, qargs, cargs in circuit.data: + qreg_offset = {} + creg_offset = {} + prevq_offset = 0 + prevc_offset = 0 + for inst, qargs, cargs in circ.data: for qubit in qargs: - if qubit._register.name not in register_offset: - register_offset[qubit._register.name] = previous_offset - previous_offset += qubit._register.size + if qubit._register.name not in qreg_offset: + qreg_offset[qubit._register.name] = prevq_offset + prevq_offset += qubit._register.size + for bit in cargs: + if bit._register.name not in creg_offset: + creg_offset[bit._register.name] = prevc_offset + prevc_offset += bit._register.size qubit_indices = [ - qargs[i]._index + register_offset[qargs[i]._register.name] - for i in range(len(qargs)) + qargs[i]._index + qreg_offset[qargs[i]._register.name] for i in range(len(qargs)) ] if isinstance(inst, QuantumChannelInstruction): @@ -121,26 +166,71 @@ def get_stim_circuits(circuit_dict: Dict[int, QuantumCircuit]): # Gates and measurements if inst.name in qiskit_to_stim_dict: if len(cargs) > 0: # keeping track of measurement indices in stim - measurement_data.append( - [ - cargs[0]._index + register_offset[qargs[0]._register.name], - qargs[0]._register.name, - ] - ) + measurement_data.append([cargs[0]._register.name, cargs[0]._index]) + if qiskit_to_stim_dict[inst.name] == "TICK": # barrier stim_circuit.append("TICK") + elif inst.condition is not None: # handle c_ifs + if inst.name in "xyz": + if inst.condition[1] == 1: + clbit = inst.condition[0] + stim_circuit.append( + qiskit_to_stim_dict["c" + inst.name], + [ + StimTarget_rec( + measurement_data.index( + [clbit._register.name, clbit._index] + ) + - len(measurement_data) + ), + qubit_indices[0], + ], + ) + else: + raise Exception( + "Classically controlled gate must be conditioned on bit value 1" + ) + else: + raise Exception( + "Classically controlled " + inst.name + " gate is not supported" + ) else: # gates/measurements acting on qubits stim_circuit.append(qiskit_to_stim_dict[inst.name], qubit_indices) else: raise Exception("Unexpected operations: " + str([inst, qargs, cargs])) - stim_circuits[circ_label] = stim_circuit - stim_measurement_data[circ_label] = measurement_data + if detectors != [{}]: + for det in detectors: + stim_record_targets = [] + for reg, ind in det["clbits"]: + stim_record_targets.append( + StimTarget_rec(measurement_data.index([reg, ind]) - len(measurement_data)) + ) + if det["time"] != []: + stim_circuit.append( + "DETECTOR", stim_record_targets, det["qubits"] + [det["time"]] + ) + else: + stim_circuit.append("DETECTOR", stim_record_targets, []) + if logicals != [{}]: + for log_ind, log in enumerate(logicals): + stim_record_targets = [] + for reg, ind in log["clbits"]: + stim_record_targets.append( + StimTarget_rec(measurement_data.index([reg, ind]) - len(measurement_data)) + ) + stim_circuit.append("OBSERVABLE_INCLUDE", stim_record_targets, log_ind) + + stim_circuits.append(stim_circuit) + stim_measurement_data.append(measurement_data) + return stim_circuits, stim_measurement_data def get_counts_via_stim( - circuits: Union[List, QuantumCircuit], shots: int = 4000, noise_model: PauliNoiseModel = None + circuits: Union[List, QuantumCircuit], + shots: int = 4000, + noise_model: PauliNoiseModel = None, ): """Returns a qiskit compatible dictionary of measurement outcomes @@ -162,17 +252,17 @@ def get_counts_via_stim( counts = [] for circuit in circuits: - stim_circuits, stim_measurement_data = get_stim_circuits({"": circuit}) - stim_circuit = stim_circuits[""] - measurement_data = stim_measurement_data[""] + stim_circuits, stim_measurement_data = get_stim_circuits(circuit) + stim_circuit = stim_circuits[0] + measurement_data = stim_measurement_data[0] stim_samples = stim_circuit.compile_sampler().sample(shots=shots) qiskit_counts = {} for stim_sample in stim_samples: - prev_reg = measurement_data[-1][1] + prev_reg = measurement_data[-1][0] qiskit_count = "" for idx, meas in enumerate(measurement_data[::-1]): - _, reg = meas + reg, _ = meas if reg != prev_reg: qiskit_count += " " qiskit_count += str(int(stim_sample[-idx - 1])) @@ -189,30 +279,101 @@ def get_counts_via_stim( return counts -def detector_error_model_to_rx_graph(model: StimDetectorErrorModel) -> rx.PyGraph: +def iter_flatten_model( + model: StimDetectorErrorModel, + handle_error: Callable[[float, List[int], List[int]], None], + handle_detector_coords: Callable[[int, np.ndarray], None], + detectors: List[Dict], + hyperedges: List[Dict], +): + """ + This function have been copied from the built-in method of + stim: stim.Circuit.generated("surface_code:rotated_memory_z",...) + """ + + det_offset = 0 + + def _helper(m: StimDetectorErrorModel, reps: int): + nonlocal det_offset + for _ in range(reps): + for instruction in m: + if isinstance(instruction, StimDemRepeatBlock): + _helper(instruction.body_copy(), instruction.repeat_count) + elif isinstance(instruction, StimDemInstruction): + if instruction.type == "error": + dets: List[int] = [] + frames: List[int] = [] + t: StimDemTarget + p = instruction.args_copy()[0] + hyperedge = {} + for t in instruction.targets_copy(): + if t.is_relative_detector_id(): + dets.append(t.val + det_offset) + elif t.is_logical_observable_id(): + frames.append(t.val) + elif t.is_separator(): + # Treat each component of a decomposed error as an independent error. + handle_error(p, dets, frames, hyperedge) + frames = [] + dets = [] + # Handle last component. + handle_error(p, dets, frames, hyperedge) + if len(hyperedge) > 1: + hyperedges.append(hyperedge) + elif instruction.type == "shift_detectors": + det_offset += instruction.targets_copy()[0] + elif instruction.type == "detector": + t = instruction.targets_copy()[0] + det_ind = t.val + det_offset + if detectors == [{}]: + a = np.array(instruction.args_copy()) + time = a[-1] + qubits = [int(qubit_ind) for qubit_ind in a[:-1]] + det = {} + else: + det = detectors[det_ind].copy() + time = det.pop("time") + qubits = det.pop("qubits") + del det["clbits"] + for t in instruction.targets_copy(): + handle_detector_coords( + detector_index=det_ind, + time=time, + qubits=qubits, + det_props=det, + ) + elif instruction.type == "logical_observable": + pass + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + _helper(model, 1) + + +def detector_error_model_to_rx_graph( + model: StimDetectorErrorModel, detectors: List[Dict] = None +) -> rx.PyGraph: """Convert a stim error model into a RustworkX graph. It assumes that the stim circuit does not contain repeat blocks. Later on repeat blocks should be handled to make this function compatible with user-defined stim circuits. + + Args: + detectors: + coordinate included as the last element for every detector in the stim detector error model """ + if detectors is None: + detectors = [{}] + g = rx.PyGraph(multigraph=False) index_to_DecodingGraphNode = {} - for instruction in model: - if instruction.type == "detector": - a = np.array(instruction.args_copy()) - time = a[-1] - qubits = [int(qubit_ind) for qubit_ind in a[:-1]] - for t in instruction.targets_copy(): - node = DecodingGraphNode(index=t.val, time=time, qubits=qubits) - index_to_DecodingGraphNode[t.val] = node - g.add_node(node) - - trivial_boundary_node = DecodingGraphNode(index=model.num_detectors, time=0, is_boundary=True) - g.add_node(trivial_boundary_node) - index_to_DecodingGraphNode[model.num_detectors] = trivial_boundary_node + def skip_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: Dict): + pass def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: Dict): if p == 0: @@ -244,51 +405,179 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: ) edge = DecodingGraphEdge( qubits=qubits, - weight=log((1 - p) / p), + weight=loga((1 - p) / p), properties={"fault_ids": set(frame_changes), "error_probability": p}, ) g.add_edge(dets[0], dets[1], edge) hyperedge[dets[0], dets[1]] = edge + def skip_detector_coords(detector_index: int, time, qubits, det_props): + pass + + def handle_detector_coords(detector_index: int, time, qubits, det_props): + node = DecodingGraphNode(index=detector_index, time=time, qubits=qubits) + node.properties = det_props + index_to_DecodingGraphNode[detector_index] = node + g.add_node(node) + hyperedges = [] - for instruction in model: - if isinstance(instruction, StimDemInstruction): - if instruction.type == "error": - dets: List[int] = [] - frames: List[int] = [] - t: StimDemTarget - p = instruction.args_copy()[0] - hyperedge = {} - for t in instruction.targets_copy(): - if t.is_relative_detector_id(): - dets.append(t.val) - elif t.is_logical_observable_id(): - frames.append(t.val) - elif t.is_separator(): - # Treat each component of a decomposed error as an independent error. - handle_error(p, dets, frames, hyperedge) - frames = [] - dets = [] - # Handle last component. - handle_error(p, dets, frames, hyperedge) - if len(hyperedge) > 1: - hyperedges.append(hyperedge) - elif instruction.type == "detector": - pass - elif instruction.type == "logical_observable": - pass - else: - raise NotImplementedError() - else: - raise NotImplementedError() + iter_flatten_model( + model, + handle_error=skip_error, + handle_detector_coords=handle_detector_coords, + detectors=detectors, + hyperedges=hyperedges, + ) + + trivial_boundary_node = DecodingGraphNode(index=model.num_detectors, time=0, is_boundary=True) + g.add_node(trivial_boundary_node) + index_to_DecodingGraphNode[model.num_detectors] = trivial_boundary_node + + iter_flatten_model( + model, + handle_error=handle_error, + handle_detector_coords=skip_detector_coords, + detectors=detectors, + hyperedges=hyperedges, + ) return g, hyperedges +def string2nodes_with_detectors( + string: str, + detectors: List[Dict], + logicals: List[Dict], + clbits: QuantumCircuit.clbits, + det_ref_values: Union[List, int] = 0, + **kwargs, +): + """ + Convert output string from circuits into a set of nodes for + `DecodingGraph`. + Args: + string (string): Results string to convert. + detectors: A list of measurement comparisons. A measurement comparison + (detector) is either a list of measurements given by a the name and index + of the classical bit or a list of dictionaries, with a mandatory clbits + key containing the classical bits. A dictionary can contain keys like + 'qubits', 'time', 'basis' etc. + logicals: A list of logical measurements. A logical measurement is a + list of classical bits whose total parity is the logical eigenvalue. + Again it can be a list of dictionaries. + clbits: classical bits of the qiskit circuit, needed to identify + measurements in the output string + det_ref_values: Reference value for the detector outcomes, 0 by default + + kwargs (dict): Any additional keyword arguments. + logical (str): Logical value whose results are used ('0' as default). + all_logicals (bool): Whether to include logical nodes + irrespective of value. (False as default). + """ + + output_bits = np.array([int(char) for char in string.replace(" ", "")[::-1]]) + + clbit_dict = {(clbit._register.name, clbit._index): clind for clind, clbit in enumerate(clbits)} + + if isinstance(det_ref_values, int): + det_ref_values = [det_ref_values] * len(detectors) + + nodes = [] + for ind, det in enumerate(detectors): + det = det.copy() + outcomes = [clbit_dict[clbit_key] for clbit_key in det.pop("clbits")] + if sum(output_bits[outcomes]) % 2 != det_ref_values[ind]: + node = DecodingGraphNode(time=det.pop("time"), qubits=det.pop("qubits"), index=ind) + node.properties = det + nodes.append(node) + + log_nodes = string2rawlogicals_with_detectors( + string=string, logicals=logicals, clbits=clbits, start_ind=len(detectors), **kwargs + ) + + for node in log_nodes: + nodes.append(node) + + return nodes + + +def string2rawlogicals_with_detectors( + string: str, + logicals: List[Dict], + clbits: QuantumCircuit.clbits, + start_ind: int = 0, + **kwargs, +): + """ + Convert output string from circuits into raw logical values. + """ + + all_logicals = kwargs.get("all_logicals") + logical = kwargs.get("logical") + if logical is None: + logical = "0" + + output_bits = np.array([int(char) for char in string.replace(" ", "")[::-1]]) + + clbit_dict = {(clbit._register.name, clbit._index): clind for clind, clbit in enumerate(clbits)} + + nodes = [] + for index, logical_op in enumerate(logicals, start=start_ind): + logical_out = 0 + for q in logical_op["clbits"]: + qind = clbit_dict[q] + logical_out += output_bits[qind] + logical_out = logical_out % 2 + + if all_logicals or str(logical_out) != logical: + node = DecodingGraphNode( + is_boundary=True, + qubits=[], + index=index, + ) + nodes.append(node) + + return nodes + + +def string2logical_meas( + string: str, + outcomes_in_logical: List[Dict], + clbits: QuantumCircuit.clbits, +): + """ + Args: + string (string): Results string from qiskit circuit + outcomes_in_logical: the detector-style logical outcome + clbits: classical bits of the qiskit circuit, needed to identify + measurements in the output string + """ + + output_bits = np.array([int(char) for char in string.replace(" ", "")[::-1]]) + + clbit_dict = {(clbit._register.name, clbit._index): clind for clind, clbit in enumerate(clbits)} + + log_outs = [] + for logical_op in outcomes_in_logical: + logical_out = 0 + for q in logical_op["clbits"]: + qind = clbit_dict[q] + logical_out += output_bits[qind] + logical_out = logical_out % 2 + log_outs.append(logical_out) + + return log_outs + + def noisify_circuit(circuits: Union[List, QuantumCircuit], noise_model: PauliNoiseModel): """ Inserts error operations into a circuit according to a pauli noise model. + Handles idling errors in the form of custom gates "idle_#" which are assumed to + encode the identity gate only. + qc = QuantumCircuit(1, name='idle_1') + qc.i(0) + idle_1 = qc.to_instruction() Args: circuits: Circuit or list thereof to which noise is added. @@ -305,10 +594,9 @@ def noisify_circuit(circuits: Union[List, QuantumCircuit], noise_model: PauliNoi # create pauli errors for all errors in noise model errors = {} for g, noise in noise_model.to_dict().items(): - errors[g] = [] - for pauli, prob in noise["chan"].items(): - pauli = pauli.upper() - errors[g].append(pauli_error([(pauli, prob), ("I" * len(pauli), 1 - prob)])) + paulis = [pauli.upper() for pauli in noise["chan"].keys()] + probs = list(noise["chan"].values()) + errors[g] = pauli_error(list(zip(paulis, probs))) noisy_circuits = [] for qc in circuits: @@ -327,11 +615,11 @@ def noisify_circuit(circuits: Union[List, QuantumCircuit], noise_model: PauliNoi noisy_qc.append(gate) # then the error if g in errors: - for error_op in errors[g]: - noisy_qc.append(error_op, qubits) + noisy_qc.append(errors[g], qubits) # add gate if it needs to go after the error if not pre_error: - noisy_qc.append(gate) + if not g.startswith("idle_"): + noisy_qc.append(gate) noisy_circuits.append(noisy_qc) diff --git a/test/code_circuits/test_css_codes_with_stim.py b/test/code_circuits/test_css_codes_with_stim.py index 2a6cb9b4..8e2b99d8 100644 --- a/test/code_circuits/test_css_codes_with_stim.py +++ b/test/code_circuits/test_css_codes_with_stim.py @@ -19,6 +19,7 @@ from qiskit_qec.codes.hhc import HHC from qiskit_qec.circuits.css_code import CSSCodeCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.utils.stim_tools import get_stim_circuits class TestCircuitMatcher(unittest.TestCase): @@ -38,7 +39,10 @@ def log_failure_dists(self, error_rate: float): css_code = CSSCodeCircuit(code, T=d, basis="x", noise_model=(error_rate, error_rate)) graph = DecodingGraph(css_code).graph m = pymatching.Matching(graph) - stim_circuit = css_code.stim_circuit_with_detectors()["0"] + detectors, logicals = css_code.stim_detectors() + stim_circuit = get_stim_circuits( + css_code.noisy_circuit["0"], detectors=detectors, logicals=logicals + )[0][0] stim_sampler = stim_circuit.compile_detector_sampler() num_correct = 0 stim_samples = stim_sampler.sample(num_shots, append_observables=True)