From 4e4e34a031811c5985e8ca4bb25f05da17e8048b Mon Sep 17 00:00:00 2001 From: James Wootton Date: Thu, 16 Nov 2023 16:48:03 +0100 Subject: [PATCH] 404 stim compatibility (#5) * Update stim_code_circuit.py * Update stim_tools.py * Update stim_tools.py * more linting * remove decoder dependence on css_logical * final linting (hopefully) --- src/qiskit_qec/circuits/__init__.py | 1 + src/qiskit_qec/circuits/code_circuit.py | 8 ++++++++ src/qiskit_qec/circuits/css_code.py | 7 +++++++ src/qiskit_qec/circuits/repetition_code.py | 8 +++++++- src/qiskit_qec/circuits/stim_code_circuit.py | 9 ++++----- src/qiskit_qec/circuits/surface_code.py | 7 +++++++ src/qiskit_qec/decoders/hdrg_decoders.py | 9 ++------- src/qiskit_qec/utils/stim_tools.py | 19 +++++++++++++------ test/code_circuits/test_rep_codes.py | 4 ++-- 9 files changed, 51 insertions(+), 21 deletions(-) diff --git a/src/qiskit_qec/circuits/__init__.py b/src/qiskit_qec/circuits/__init__.py index 23bd413b..433410af 100644 --- a/src/qiskit_qec/circuits/__init__.py +++ b/src/qiskit_qec/circuits/__init__.py @@ -34,3 +34,4 @@ from .repetition_code import RepetitionCodeCircuit, ArcCircuit from .surface_code import SurfaceCodeCircuit from .css_code import CSSCodeCircuit +from .stim_code_circuit import StimCodeCircuit diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index 916e9698..714b27fc 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -52,6 +52,14 @@ def string2nodes(self, string, **kwargs): """ pass + @abstractmethod + def measured_logicals(self): + """ + Returns a list of logical operators, each expressed as a list of qubits for which + the parity of the final readouts corresponds to the raw logical readout. + """ + pass + @abstractmethod def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ diff --git a/src/qiskit_qec/circuits/css_code.py b/src/qiskit_qec/circuits/css_code.py index 415bcb3d..9133b133 100644 --- a/src/qiskit_qec/circuits/css_code.py +++ b/src/qiskit_qec/circuits/css_code.py @@ -186,6 +186,13 @@ def _get_code_properties(self): self.css_x_logical = self.logical_x self.css_z_logical = self.logical_z + def measured_logicals(self): + if self.basis == "x": + measured_logicals = self.logical_x + else: + measured_logicals = self.logical_z + return measured_logicals + def _prepare_initial_state(self, qc, qregs, state): if state[0] == "1": if self.basis == "z": diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 40bb3f52..8123289e 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -241,6 +241,9 @@ def readout(self): self.circuit[log].add_register(self.code_bit) self.circuit[log].measure(self.code_qubit, self.code_bit) + def measured_logicals(self): + return [[0]] + def _process_string(self, string): # logical readout taken from measured_log = string[0] + " " + string[self.d - 1] @@ -349,7 +352,7 @@ def string2raw_logicals(self, string): Returns: list: Raw values for logical operators that correspond to nodes. """ - return _separate_string(self._process_string(string))[0] + return string.split(" ", maxsplit=1)[0][-1] def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ @@ -964,6 +967,9 @@ def _readout(self): qc.add_register(self.code_bit) qc.measure(self.code_qubit, self.code_bit) + def measured_logicals(self): + return [[self.z_logicals[0]]] + def _process_string(self, string): # logical readout taken from assigned qubits measured_log = "" diff --git a/src/qiskit_qec/circuits/stim_code_circuit.py b/src/qiskit_qec/circuits/stim_code_circuit.py index cb130657..888fc9d3 100644 --- a/src/qiskit_qec/circuits/stim_code_circuit.py +++ b/src/qiskit_qec/circuits/stim_code_circuit.py @@ -12,8 +12,9 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -# pylint: disable=invalid-name, disable=no-name-in-module +# pylint: disable=invalid-name, disable=no-name-in-module, disable=no-member +"""Generates CodeCircuits from stim circuits""" import warnings from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister @@ -112,9 +113,7 @@ def _helper(stim_circuit: StimCircuit, reps: int): rep_block_count += 1 elif isinstance(instruction, CircuitInstruction): inst_name = instruction.name - if inst_name == "QUBIT_COORDS": - m = 1 - elif inst_name in single_qubit_gate_dict: + if inst_name in single_qubit_gate_dict: qubits = [target.value for target in instruction.targets_copy()] for q in qubits: self.qc.append(single_qubit_gate_dict[inst_name], qargs=[q]) @@ -211,7 +210,7 @@ def _helper(stim_circuit: StimCircuit, reps: int): # further code parameters try: self.d = len(self.stim_circuit.shortest_graphlike_error()) # code distance - except: + except ValueError: self.d = 0 self.n = stim_circuit.num_qubits # the number of rounds is not necessarily well-defined (Floquet codes etc.) diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 50162caa..4775c800 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -316,6 +316,13 @@ def _string2changes(self, string): return syndrome_changes + def measured_logicals(self): + if self.basis == "x": + measured_logicals = self.css_x_logical + else: + measured_logicals = self.css_z_logical + return measured_logicals + def string2raw_logicals(self, string): """ Extracts raw logicals from output string. diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 4bd54dde..3f1c2bbe 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -40,13 +40,8 @@ def __init__( ): self.code = code_circuit - if hasattr(self.code, "_xbasis"): - if self.code._xbasis: - self.measured_logicals = self.code.css_x_logical - else: - self.measured_logicals = self.code.css_z_logical - else: - self.measured_logicals = self.code.css_z_logical + self.measured_logicals = self.code.measured_logicals() + if hasattr(self.code, "code_index"): self.code_index = self.code.code_index else: diff --git a/src/qiskit_qec/utils/stim_tools.py b/src/qiskit_qec/utils/stim_tools.py index 107e8fb8..a55e6ddf 100644 --- a/src/qiskit_qec/utils/stim_tools.py +++ b/src/qiskit_qec/utils/stim_tools.py @@ -16,7 +16,7 @@ """Tools to use functionality from Stim.""" from typing import Union, List, Dict, Callable -from math import log +from math import log as loga from stim import Circuit as StimCircuit from stim import DetectorErrorModel as StimDetectorErrorModel from stim import DemInstruction as StimDemInstruction @@ -404,7 +404,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: ) edge = DecodingGraphEdge( qubits=qubits, - weight=log((1 - p) / p), + weight=loga((1 - p) / p), properties={"fault_ids": set(frame_changes), "error_probability": p}, ) g.add_edge(dets[0], dets[1], edge) @@ -457,10 +457,17 @@ def string2nodes_with_detectors( `DecodingGraph`. Args: string (string): Results string to convert. - detectors: - logicals: - clbits: - det_ref_values: + detectors: A list of measurement comparisons. A measurement comparison + (detector) is either a list of measurements given by a the name and index + of the classical bit or a list of dictionaries, with a mandatory clbits + key containing the classical bits. A dictionary can contain keys like + 'qubits', 'time', 'basis' etc. + logicals: A list of logical measurements. A logical measurement is a + list of classical bits whose total parity is the logical eigenvalue. + Again it can be a list of dictionaries. + clbits: classical bits of the qiskit circuit, needed to identify + measurements in the output string + det_ref_values: Reference value for the detector outcomes, 0 by default kwargs (dict): Any additional keyword arguments. logical (str): Logical value whose results are used ('0' as default). diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 1b65cd3c..7a76430c 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -538,7 +538,7 @@ def clustering_decoder_test( # now run them all and check it works for c, code in enumerate(codes): decoding_graph = DecodingGraph(code) - if c == 3 and Decoder is UnionFindDecoder: + if c >= 0 and Decoder is UnionFindDecoder: decoder = Decoder(code, decoding_graph=decoding_graph, use_peeling=False) else: decoder = Decoder(code, decoding_graph=decoding_graph) @@ -555,7 +555,7 @@ def clustering_decoder_test( for j, z_logical in enumerate(decoder.measured_logicals): error = corrected_z_logicals[j] != 1 if error: - error_num = string.count("0") + error_num = string.split(" ", maxsplit=1)[0].count("0") if error_num < min_error_num: min_error_num = error_num min_error_string = string