Skip to content

Commit

Permalink
add predecoders (#410)
Browse files Browse the repository at this point in the history
* add predecoders

* speed up decoding for linear arcs

---------

Co-authored-by: grace-harper <119029214+grace-harper@users.noreply.github.com>
  • Loading branch information
quantumjim and grace-harper authored Jan 19, 2024
1 parent 167e259 commit c9cf30d
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 364 deletions.
73 changes: 71 additions & 2 deletions src/qiskit_qec/circuits/repetition_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,11 @@ def _get_cycles(self):
"""

self.link_graph = self._get_link_graph()
self.degree = {}
for n, q in enumerate(self.link_graph.nodes()):
self.degree[q] = self.link_graph.degree(n)
degrees = list(self.degree.values())
self._linear = degrees.count(1) == 2 and degrees.count(2) == len(degrees) - 2
lg_edges = set(self.link_graph.edge_list())
lg_nodes = self.link_graph.nodes()
ng = nx.Graph()
Expand Down Expand Up @@ -1357,8 +1362,11 @@ def is_cluster_neutral(self, atypical_nodes: dict):
Args:
atypical_nodes: dictionary in the form of the return value of string2nodes
"""
neutral, logicals, _ = self.check_nodes(atypical_nodes)
return neutral and not logicals
if self._linear:
return not bool(len(atypical_nodes) % 2)
else:
neutral, logicals, _ = self.check_nodes(atypical_nodes)
return neutral and not logicals

def transpile(self, backend, echo=("X", "X"), echo_num=(2, 0)):
"""
Expand Down Expand Up @@ -1649,3 +1657,64 @@ def get_error_coords(
return error_coords, sample_coords
else:
return error_coords

def clean_code(self, string):
"""
Given an output string of the code, obvious code qubit errors are identified and their effects
are removed.
Args:
string (str): Output string of the code.
Returns:
string (str): Modifed output string of the code.
"""

# get the parities for the rounds and turn them into lists of integers
# (also turn them the right way around)
parities = []
for rstring in string.split(" ")[1:]:
parities.append([int(p) for p in rstring][::-1])
parities = parities[::-1]

# calculate the final parities from the final readout and add them on
final = string.split(" ")[0]
final_parities = [0] * self.num_qubits[1]
for c0, a, c1 in self.links:
final_parities[-self.link_index[a] - 1] = (
int(final[-self.code_index[c0] - 1]) + int(final[-self.code_index[c1] - 1])
) % 2
parities.append(final_parities[::-1])

flips = {c: 0 for c in self.code_index}
for rparities in parities:
# see how many links around each code qubit detect a flip
link_count = {c: 0 for c in self.code_index}
for c0, a, c1 in self.links:
# we'll need to determine whether the as yet uncorrected parity
# checks from this round should be flipped, based on results
# from previous rounds
flip = (flips[c0] + flips[c1]) % 2
b = self.link_index[a]
for c in [c0, c1]:
link_count[c] += (rparities[b] + flip) % 2
# if it's all of them, assume a flip
for c in link_count:
if link_count[c] == self.degree[c]:
flips[c] = (flips[c] + 1) % 2
# modify the parities to remove the effect
for c0, a, c1 in self.links:
flip = (flips[c0] + flips[c1]) % 2
b = self.link_index[a]
rparities[b] = (rparities[b] + flip) % 2
# turn the results back into a string
new_string = ""
for rparities in parities[:-1][::-1]:
new_string += " " + "".join([str(p) for p in rparities][::-1])
final_string = [int(p) for p in string.split(" ", maxsplit=1)[0]]
for c, flip in flips.items():
b = self.code_index[c]
final_string[-b - 1] = (final_string[-b - 1] + flip) % 2
final_string = "".join([str(p) for p in final_string])

return final_string + new_string
2 changes: 1 addition & 1 deletion src/qiskit_qec/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
from .circuit_matching_decoder import CircuitModelMatchingDecoder
from .repetition_decoder import RepetitionDecoder
from .three_bit_decoder import ThreeBitDecoder
from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder, ClAYGDecoder
from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder
97 changes: 94 additions & 3 deletions src/qiskit_qec/decoders/decoding_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,47 @@ def __init__(self, code, brute=False, graph=None):
if node.is_boundary:
self._logical_nodes.append(node)

self.update_attributes()

def update_attributes(self):
"""
Calculates properties of the graph used by `node_index` and `edge_in_graph`.
If `graph` is updated this method should called to update these properties.
"""
self._edge_set = set(self.graph.edge_list())
self._node_index = {}
for n, node in enumerate(self.graph.nodes()):
clean_node = copy.deepcopy(node)
clean_node.properties = {}
self._node_index[clean_node] = n

def node_index(self, node):
"""
Given a node of `graph`, returns the corrsponding index.
Args:
node (DecodingGraphNode): Node of the graph.
Returns:
n (int): Index corresponding to the node within the graph.
"""
clean_node = copy.deepcopy(node)
clean_node.properties = {}
return self._node_index[clean_node]

def edge_in_graph(self, edge):
"""
Given a pair of node indices for `graph`, determines whether
the edge exists within the graph.
Args:
edge (tuple): Pair of node indices for the graph.
Returns:
in_graph (bool): Whether the edge is within the graph.
"""
return edge in self._edge_set

def _make_syndrome_graph(self):
if not self.brute and hasattr(self.code, "_make_syndrome_graph"):
self.graph, self.hyperedges = self.code._make_syndrome_graph()
Expand Down Expand Up @@ -170,7 +211,7 @@ def get_error_probs(
error_nodes = set(self.code.string2nodes(string, logical=logical))

for node0 in error_nodes:
n0 = self.graph.nodes().index(node0)
n0 = self.node_index(node0)
av_v[n0] += counts[string]
for n1 in neighbours[n0]:
node1 = self.graph[n1]
Expand Down Expand Up @@ -341,15 +382,65 @@ def weight_fn(edge):
source = E[source_index]
target = E[target_index]
if target != source:
ns = self.graph.nodes().index(source)
nt = self.graph.nodes().index(target)
ns = self.node_index(source)
nt = self.node_index(target)
distance = distance_matrix[ns][nt]
if np.isfinite(distance):
qubits = list(set(source.qubits).intersection(target.qubits))
distance = int(distance)
E.add_edge(source_index, target_index, DecodingGraphEdge(qubits, distance))
return E

def clean_measurements(self, nodes: List):
"""
Removes pairs of nodes that obviously correspond to measurement errors
from a list of nodes.
Args:
nodes: A list of nodes.
Returns:
nodes: The input list of nodes, with pairs removed if they obviously
correspond to a measurement error.
"""

# order the nodes by where and when
node_pos = {}
for node in nodes:
if not node.is_boundary:
if node.index not in node_pos:
node_pos[node.index] = {}
node_pos[node.index][node.time] = self.node_index(node)
# find pairs corresponding to time-like edges
all_pairs = set()
for node_times in node_pos.values():
ts = list(node_times.keys())
ts.sort()
for j in range(len(ts) - 1):
if ts[j + 1] - ts[j] <= 2:
n0 = node_times[ts[j]]
n1 = node_times[ts[j + 1]]
if self.edge_in_graph((n0, n1)) or self.edge_in_graph((n1, n0)):
all_pairs.add((n0, n1))
# filter out those that share nodes
all_nodes = set()
common_nodes = set()
for pair in all_pairs:
for n in pair:
if n in all_nodes:
common_nodes.add(n)
all_nodes.add(n)
paired_ns = set()
for pair in all_pairs:
if pair[0] not in common_nodes:
if pair[1] not in common_nodes:
for n in pair:
paired_ns.add(n)
# return the nodes that were not paired
ns = set(self.node_index(node) for node in nodes)
unpaired_ns = ns.difference(paired_ns)
return [self.graph.nodes()[n] for n in unpaired_ns]


class CSSDecodingGraph:
"""
Expand Down
Loading

0 comments on commit c9cf30d

Please sign in to comment.