Skip to content

Commit

Permalink
404 stim compatibility (#5)
Browse files Browse the repository at this point in the history
* Update stim_code_circuit.py

* Update stim_tools.py

* Update stim_tools.py

* more linting

* remove decoder dependence on css_logical

* final linting (hopefully)
  • Loading branch information
quantumjim authored Nov 16, 2023
1 parent 1035f97 commit 4e4e34a
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/qiskit_qec/circuits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@
from .repetition_code import RepetitionCodeCircuit, ArcCircuit
from .surface_code import SurfaceCodeCircuit
from .css_code import CSSCodeCircuit
from .stim_code_circuit import StimCodeCircuit
8 changes: 8 additions & 0 deletions src/qiskit_qec/circuits/code_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def string2nodes(self, string, **kwargs):
"""
pass

@abstractmethod
def measured_logicals(self):
"""
Returns a list of logical operators, each expressed as a list of qubits for which
the parity of the final readouts corresponds to the raw logical readout.
"""
pass

@abstractmethod
def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Expand Down
7 changes: 7 additions & 0 deletions src/qiskit_qec/circuits/css_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ def _get_code_properties(self):
self.css_x_logical = self.logical_x
self.css_z_logical = self.logical_z

def measured_logicals(self):
if self.basis == "x":
measured_logicals = self.logical_x
else:
measured_logicals = self.logical_z
return measured_logicals

def _prepare_initial_state(self, qc, qregs, state):
if state[0] == "1":
if self.basis == "z":
Expand Down
8 changes: 7 additions & 1 deletion src/qiskit_qec/circuits/repetition_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def readout(self):
self.circuit[log].add_register(self.code_bit)
self.circuit[log].measure(self.code_qubit, self.code_bit)

def measured_logicals(self):
return [[0]]

def _process_string(self, string):
# logical readout taken from
measured_log = string[0] + " " + string[self.d - 1]
Expand Down Expand Up @@ -349,7 +352,7 @@ def string2raw_logicals(self, string):
Returns:
list: Raw values for logical operators that correspond to nodes.
"""
return _separate_string(self._process_string(string))[0]
return string.split(" ", maxsplit=1)[0][-1]

def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Expand Down Expand Up @@ -964,6 +967,9 @@ def _readout(self):
qc.add_register(self.code_bit)
qc.measure(self.code_qubit, self.code_bit)

def measured_logicals(self):
return [[self.z_logicals[0]]]

def _process_string(self, string):
# logical readout taken from assigned qubits
measured_log = ""
Expand Down
9 changes: 4 additions & 5 deletions src/qiskit_qec/circuits/stim_code_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=invalid-name, disable=no-name-in-module
# pylint: disable=invalid-name, disable=no-name-in-module, disable=no-member

"""Generates CodeCircuits from stim circuits"""
import warnings

from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
Expand Down Expand Up @@ -112,9 +113,7 @@ def _helper(stim_circuit: StimCircuit, reps: int):
rep_block_count += 1
elif isinstance(instruction, CircuitInstruction):
inst_name = instruction.name
if inst_name == "QUBIT_COORDS":
m = 1
elif inst_name in single_qubit_gate_dict:
if inst_name in single_qubit_gate_dict:
qubits = [target.value for target in instruction.targets_copy()]
for q in qubits:
self.qc.append(single_qubit_gate_dict[inst_name], qargs=[q])
Expand Down Expand Up @@ -211,7 +210,7 @@ def _helper(stim_circuit: StimCircuit, reps: int):
# further code parameters
try:
self.d = len(self.stim_circuit.shortest_graphlike_error()) # code distance
except:
except ValueError:
self.d = 0
self.n = stim_circuit.num_qubits
# the number of rounds is not necessarily well-defined (Floquet codes etc.)
Expand Down
7 changes: 7 additions & 0 deletions src/qiskit_qec/circuits/surface_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ def _string2changes(self, string):

return syndrome_changes

def measured_logicals(self):
if self.basis == "x":
measured_logicals = self.css_x_logical
else:
measured_logicals = self.css_z_logical
return measured_logicals

def string2raw_logicals(self, string):
"""
Extracts raw logicals from output string.
Expand Down
9 changes: 2 additions & 7 deletions src/qiskit_qec/decoders/hdrg_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,8 @@ def __init__(
):
self.code = code_circuit

if hasattr(self.code, "_xbasis"):
if self.code._xbasis:
self.measured_logicals = self.code.css_x_logical
else:
self.measured_logicals = self.code.css_z_logical
else:
self.measured_logicals = self.code.css_z_logical
self.measured_logicals = self.code.measured_logicals()

if hasattr(self.code, "code_index"):
self.code_index = self.code.code_index
else:
Expand Down
19 changes: 13 additions & 6 deletions src/qiskit_qec/utils/stim_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

"""Tools to use functionality from Stim."""
from typing import Union, List, Dict, Callable
from math import log
from math import log as loga
from stim import Circuit as StimCircuit
from stim import DetectorErrorModel as StimDetectorErrorModel
from stim import DemInstruction as StimDemInstruction
Expand Down Expand Up @@ -404,7 +404,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge:
)
edge = DecodingGraphEdge(
qubits=qubits,
weight=log((1 - p) / p),
weight=loga((1 - p) / p),
properties={"fault_ids": set(frame_changes), "error_probability": p},
)
g.add_edge(dets[0], dets[1], edge)
Expand Down Expand Up @@ -457,10 +457,17 @@ def string2nodes_with_detectors(
`DecodingGraph`.
Args:
string (string): Results string to convert.
detectors:
logicals:
clbits:
det_ref_values:
detectors: A list of measurement comparisons. A measurement comparison
(detector) is either a list of measurements given by a the name and index
of the classical bit or a list of dictionaries, with a mandatory clbits
key containing the classical bits. A dictionary can contain keys like
'qubits', 'time', 'basis' etc.
logicals: A list of logical measurements. A logical measurement is a
list of classical bits whose total parity is the logical eigenvalue.
Again it can be a list of dictionaries.
clbits: classical bits of the qiskit circuit, needed to identify
measurements in the output string
det_ref_values: Reference value for the detector outcomes, 0 by default
kwargs (dict): Any additional keyword arguments.
logical (str): Logical value whose results are used ('0' as default).
Expand Down
4 changes: 2 additions & 2 deletions test/code_circuits/test_rep_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def clustering_decoder_test(
# now run them all and check it works
for c, code in enumerate(codes):
decoding_graph = DecodingGraph(code)
if c == 3 and Decoder is UnionFindDecoder:
if c >= 0 and Decoder is UnionFindDecoder:
decoder = Decoder(code, decoding_graph=decoding_graph, use_peeling=False)
else:
decoder = Decoder(code, decoding_graph=decoding_graph)
Expand All @@ -555,7 +555,7 @@ def clustering_decoder_test(
for j, z_logical in enumerate(decoder.measured_logicals):
error = corrected_z_logicals[j] != 1
if error:
error_num = string.count("0")
error_num = string.split(" ", maxsplit=1)[0].count("0")
if error_num < min_error_num:
min_error_num = error_num
min_error_string = string
Expand Down

0 comments on commit 4e4e34a

Please sign in to comment.