From 827cc08c52a460de759cf5f4772cbabdd9661f8b Mon Sep 17 00:00:00 2001 From: Grace Date: Thu, 14 Sep 2023 16:22:45 -0400 Subject: [PATCH] linted --- .pylintrc | 3 - src/qiskit_qec/circuits/repetition_code.py | 178 +++++++++++------- src/qiskit_qec/circuits/surface_code.py | 2 +- src/qiskit_qec/decoders/decoding_graph.py | 62 +++--- src/qiskit_qec/decoders/hdrg_decoders.py | 66 ++++--- src/qiskit_qec/linear/symplectic.py | 67 ++++--- .../utils/decoding_graph_attributes.py | 8 +- src/qiskit_qec/utils/stim_tools.py | 2 +- test/code_circuits/test_rep_codes.py | 89 ++++++--- test/union_find/test_clayg.py | 13 +- test/utils/test_visualization.py | 20 +- 11 files changed, 329 insertions(+), 181 deletions(-) diff --git a/.pylintrc b/.pylintrc index 2d44948f..7e2e1dc7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -476,9 +476,6 @@ disable=raw-checker-failed, missing-raises-doc, logging-too-many-args, spelling, # way too noisy - no-self-use, # disabled as it is too verbose - bad-continuation, bad-whitespace # differences of opinion with black - # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 04b64cb0..cb9d4f2d 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -15,19 +15,19 @@ # pylint: disable=invalid-name """Generates circuits based on repetition codes.""" +from copy import deepcopy from typing import List, Optional, Tuple -from copy import deepcopy import numpy as np import rustworkx as rx - -from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister, transpile -from qiskit.circuit.library import XGate, RZGate -from qiskit.transpiler import PassManager, InstructionDurations +from qiskit import (ClassicalRegister, QuantumCircuit, QuantumRegister, + transpile) +from qiskit.circuit.library import RZGate, XGate +from qiskit.transpiler import InstructionDurations, PassManager from qiskit.transpiler.passes import DynamicalDecoupling from qiskit_qec.circuits.code_circuit import CodeCircuit -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphEdge, DecodingGraphNode def _separate_string(string): @@ -87,7 +87,8 @@ def __init__( self.circuit = {} for log in ["0", "1"]: - self.circuit[log] = QuantumCircuit(self.link_qubit, self.code_qubit, name=log) + self.circuit[log] = QuantumCircuit( + self.link_qubit, self.code_qubit, name=log) self._xbasis = xbasis self._resets = resets @@ -184,7 +185,8 @@ def syndrome_measurement(self, final: bool = False, barrier: bool = False, delay """ barrier = barrier or self._barriers - self.link_bits.append(ClassicalRegister((self.d - 1), "round_" + str(self.T) + "_link_bit")) + self.link_bits.append(ClassicalRegister( + (self.d - 1), "round_" + str(self.T) + "_link_bit")) for log in ["0", "1"]: self.circuit[log].add_register(self.link_bits[-1]) @@ -196,14 +198,18 @@ def syndrome_measurement(self, final: bool = False, barrier: bool = False, delay self.circuit[log].h(self.link_qubit) for j in range(self.d - 1): if self._xbasis: - self.circuit[log].cx(self.link_qubit[j], self.code_qubit[j]) + self.circuit[log].cx( + self.link_qubit[j], self.code_qubit[j]) else: - self.circuit[log].cx(self.code_qubit[j], self.link_qubit[j]) + self.circuit[log].cx( + self.code_qubit[j], self.link_qubit[j]) for j in range(self.d - 1): if self._xbasis: - self.circuit[log].cx(self.link_qubit[j], self.code_qubit[j + 1]) + self.circuit[log].cx( + self.link_qubit[j], self.code_qubit[j + 1]) else: - self.circuit[log].cx(self.code_qubit[j + 1], self.link_qubit[j]) + self.circuit[log].cx( + self.code_qubit[j + 1], self.link_qubit[j]) if self._xbasis: self.circuit[log].h(self.link_qubit) @@ -211,7 +217,8 @@ def syndrome_measurement(self, final: bool = False, barrier: bool = False, delay if barrier: self.circuit[log].barrier() for j in range(self.d - 1): - self.circuit[log].measure(self.link_qubit[j], self.link_bits[self.T][j]) + self.circuit[log].measure( + self.link_qubit[j], self.link_bits[self.T][j]) # resets if self._resets and not final: @@ -245,15 +252,16 @@ def _process_string(self, string): measured_log = string[0] + " " + string[self.d - 1] if self._resets: - syndrome = string[self.d :] + syndrome = string[self.d:] else: # if there are no resets, results are cumulative and need to be separated - cumsyn_list = string[self.d :].split(" ") + cumsyn_list = string[self.d:].split(" ") syndrome_list = [] for tt, cum_syn in enumerate(cumsyn_list[0:-1]): syn = "" for j in range(len(cum_syn)): - syn += str(int(cumsyn_list[tt][j] != cumsyn_list[tt + 1][j])) + syn += str(int(cumsyn_list[tt][j] + != cumsyn_list[tt + 1][j])) syndrome_list.append(syn) syndrome_list.append(cumsyn_list[-1]) syndrome = " ".join(syndrome_list) @@ -261,7 +269,9 @@ def _process_string(self, string): # final syndrome deduced from final code qubit readout full_syndrome = "" for j in range(self.d - 1): - full_syndrome += "0" * (string[j] == string[j + 1]) + "1" * (string[j] != string[j + 1]) + full_syndrome += "0" * \ + (string[j] == string[j + 1]) + \ + "1" * (string[j] != string[j + 1]) # results from all other syndrome measurements then added full_syndrome = full_syndrome + syndrome @@ -310,7 +320,8 @@ def string2nodes(self, string, **kwargs): logical = "0" string = self._process_string(string) - separated_string = _separate_string(string) # [ , , ,...] + # [ , , ,...] + separated_string = _separate_string(string) nodes = [] # boundary nodes @@ -322,7 +333,8 @@ def string2nodes(self, string, **kwargs): bqubits = [self.css_x_logical[i]] else: bqubits = [self.css_z_logical[i]] - bnode = DecodingGraphNode(is_boundary=True, qubits=bqubits, index=bqec_index) + bnode = DecodingGraphNode( + is_boundary=True, qubits=bqubits, index=bqec_index) nodes.append(bnode) # bulk nodes @@ -335,7 +347,8 @@ def string2nodes(self, string, **kwargs): qubits = self.css_z_gauge_ops[qec_index] else: qubits = self.css_x_gauge_ops[qec_index] - node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) + node = DecodingGraphNode( + time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes @@ -434,10 +447,11 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): elem = self.css_z_boundary.index(qubits) else: elem = self.css_x_boundary.index(qubits) - node = DecodingGraphNode(is_boundary=True, qubits=qubits, index=elem) + node = DecodingGraphNode( + is_boundary=True, qubits=qubits, index=elem) flipped_logical_nodes.append(node) - if neutral and flipped_logical_nodes == []: + if neutral and not flipped_logical_nodes: break return neutral, flipped_logical_nodes, num_errors @@ -467,7 +481,8 @@ def partition_outcomes( for i, layer in enumerate(gauge_outcomes): for j, gauge_op in enumerate(layer): if i > 0: - gauge_outcomes[i][j] = (gauge_op + gauge_outcomes[i - 1][j]) % 2 + gauge_outcomes[i][j] = ( + gauge_op + gauge_outcomes[i - 1][j]) % 2 # assign outcomes to the correct gauge ops if round_schedule == "z": x_gauge_outcomes = [] @@ -574,14 +589,17 @@ def __init__( if run_202: self.links_202 = [] for link in self.links: - logical_overlap = {link[0], link[2]}.intersection(set(self.z_logicals)) + logical_overlap = {link[0], link[2] + }.intersection(set(self.z_logicals)) if not logical_overlap: self.links_202.append(link) num_links = len(self.links_202) if num_links > 0: self.rounds_per_link = int(np.floor(T / num_links)) - self.metabuffer = np.ceil((T - num_links * self.rounds_per_link) / 2) - self.roundbuffer = np.ceil((self.rounds_per_link - self.rounds_per_202) / 2) + self.metabuffer = np.ceil( + (T - num_links * self.rounds_per_link) / 2) + self.roundbuffer = np.ceil( + (self.rounds_per_link - self.rounds_per_202) / 2) if self.roundbuffer > 0: self.roundbuffer -= 1 self.run_202 = self.rounds_per_link >= self.rounds_per_202 @@ -605,7 +623,8 @@ def __init__( def _get_link_graph(self, max_dist=1): graph = rx.PyGraph() for link in self.links: - add_edge(graph, (link[0], link[2]), {"distance": 1, "link qubit": link[1]}) + add_edge(graph, (link[0], link[2]), { + "distance": 1, "link qubit": link[1]}) distance = rx.distance_matrix(graph) edges = graph.edge_list() for n0, node0 in enumerate(graph.nodes()): @@ -627,7 +646,8 @@ def _get_cycles(self): lg_edges = set(link_graph.edge_list()) lg_nodes = link_graph.nodes() cycles = rx.cycle_basis(link_graph) - cycle_dict = {(lg_nodes[edge[0]], lg_nodes[edge[1]]): list(edge) for edge in lg_edges} + cycle_dict = {(lg_nodes[edge[0]], lg_nodes[edge[1]]) + : list(edge) for edge in lg_edges} for cycle in cycles: edges = [] cl = len(cycle) @@ -685,7 +705,8 @@ def _get_coupling_graph(self, aux=None): add_edge(graph, (link[1], link[2]), {}) # we use max degree of the nodes as the edge weight, to delay bottlenecks for e, (n0, n1) in enumerate(graph.edge_list()): - graph.edges()[e]["weight"] = max(graph.degree(n0), graph.degree(n1)) + graph.edges()[e]["weight"] = max( + graph.degree(n0), graph.degree(n1)) return graph @@ -706,12 +727,14 @@ def weight_fn(edge): graph = self._get_coupling_graph(aux) # find a min weight matching, and then another that exlcudes the pairs from the first - matching = [rx.max_weight_matching(graph, max_cardinality=True, weight_fn=weight_fn)] + matching = [rx.max_weight_matching( + graph, max_cardinality=True, weight_fn=weight_fn)] cut_graph = deepcopy(graph) for n0, n1 in matching[0]: cut_graph.remove_edge(n0, n1) matching.append( - rx.max_weight_matching(cut_graph, max_cardinality=True, weight_fn=weight_fn) + rx.max_weight_matching( + cut_graph, max_cardinality=True, weight_fn=weight_fn) ) # rewrite the matchings to use nodes instead of indices, and to always place @@ -735,7 +758,8 @@ def weight_fn(edge): # add these matched pairs to the schedule for j in range(2): - schedule.append([pair for pair in matching[j] if pair[1] in completed]) + schedule.append( + [pair for pair in matching[j] if pair[1] in completed]) # update the list of auxilliaries for links yet to be paired aux = aux.difference(completed) @@ -798,7 +822,8 @@ def _preparation(self): # create the circuits and initialize the code qubits self.circuit = {} for basis in list({self.basis, self.basis[::-1]}): - self.circuit[basis] = QuantumCircuit(self.link_qubit, self.code_qubit, name=basis) + self.circuit[basis] = QuantumCircuit( + self.link_qubit, self.code_qubit, name=basis) if self.logical == "1": self.circuit[basis].x(self.code_qubit) self._basis_change(basis) @@ -857,7 +882,8 @@ def _get_202(self, t): neighbors = list(graph.incident_edges(n)) neighbors.remove(list(edges).index(tuple(ns))) qubit_l_nghbrs.append( - [graph.get_edge_data_by_index(ngbhr)["link qubit"] for ngbhr in neighbors] + [graph.get_edge_data_by_index( + ngbhr)["link qubit"] for ngbhr in neighbors] ) return tau, qubit_l_202, qubit_l_nghbrs @@ -869,7 +895,8 @@ def _syndrome_measurement(self, final: bool = False): """ self.link_bits.append( - ClassicalRegister(self.num_qubits[1], "round_" + str(self.T) + "_link_bit") + ClassicalRegister( + self.num_qubits[1], "round_" + str(self.T) + "_link_bit") ) tau, qubit_l_202, qubit_l_nghbrs = self._get_202(self.T) @@ -910,7 +937,8 @@ def _syndrome_measurement(self, final: bool = False): if self.resets and not final: for q_l in links_to_reset: if self.conditional_reset: - qc.x(self.link_qubit[q_l]).c_if(self.link_bits[self.T][q_l], 1) + qc.x(self.link_qubit[q_l]).c_if( + self.link_bits[self.T][q_l], 1) else: qc.reset(self.link_qubit[q_l]) @@ -924,7 +952,8 @@ def _syndrome_measurement(self, final: bool = False): # get the first listed neighbouring link control_link = self.links[self.link_index[qubit_l_nghbrs[j][0]]] # find the qubit on which it overlaps with the 202 - qubit_t = list(set(target_link).intersection(set(control_link)))[0] + qubit_t = list(set(target_link).intersection( + set(control_link)))[0] # and the qubit whose result controls the feedforward qubit_c = control_link[1] # get their indices @@ -933,7 +962,8 @@ def _syndrome_measurement(self, final: bool = False): # and the colour of the targeted qubit c = self.color[qubit_t] self._rotate(basis, c, self.code_qubit[q_t], True) - qc.x(self.code_qubit[q_t]).c_if(self.link_bits[self.T][q_c], 1) + qc.x(self.code_qubit[q_t]).c_if( + self.link_bits[self.T][q_c], 1) self._rotate(basis, c, self.code_qubit[q_t], False) # delay @@ -965,15 +995,16 @@ def _process_string(self, string): measured_log += string[self.num_qubits[0] - j - 1] + " " if self.resets: - syndrome = string[self.num_qubits[0] :] + syndrome = string[self.num_qubits[0]:] else: # if there are no resets, results are cumulative and need to be separated - cumsyn_list = string[self.num_qubits[0] :].split(" ") + cumsyn_list = string[self.num_qubits[0]:].split(" ") syndrome_list = [] for tt, cum_syn in enumerate(cumsyn_list[0:-1]): syn = "" for j in range(len(cum_syn)): - syn += str(int(cumsyn_list[tt][j] != cumsyn_list[tt + 1][j])) + syn += str(int(cumsyn_list[tt][j] + != cumsyn_list[tt + 1][j])) syndrome_list.append(syn) syndrome_list.append(cumsyn_list[-1]) syndrome = " ".join(syndrome_list) @@ -981,7 +1012,8 @@ def _process_string(self, string): # final syndrome deduced from final code qubit readout full_syndrome = "" for link in self.links: - q = [self.num_qubits[0] - 1 - self.code_index[link[j]] for j in [0, -1]] + q = [self.num_qubits[0] - 1 - self.code_index[link[j]] + for j in [0, -1]] full_syndrome += "0" * (string[q[0]] == string[q[1]]) + "1" * ( string[q[0]] != string[q[1]] ) @@ -1033,7 +1065,8 @@ def _process_string(self, string): dt = 1 # for those where we now have a dt, calculate the change if dt: - change = syndrome_list[-t - 1][j] != syndrome_list[-t - 1 + dt][j] + change = syndrome_list[-t - + 1][j] != syndrome_list[-t - 1 + dt][j] syndrome_changes += "0" * (not change) + "1" * change syndrome_changes += " " last_neighbors = all_neighbors.copy() @@ -1098,7 +1131,8 @@ def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: qubits=code_qubits, index=elem_num, ) - node.properties["conjugate"] = ((tau % 2) == 1) and tau > 1 + node.properties["conjugate"] = ( + (tau % 2) == 1) and tau > 1 node.properties["link qubit"] = link_qubit nodes.append(node) return nodes @@ -1215,7 +1249,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): num_nodes[new_c] += 1 # if it is coloured, check the colour is correct else: - base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2) + base_neutral = base_neutral and ( + node_color[nn] == (c + dc) % 2) for nn, c in newly_colored.items(): node_color[nn] = c if nn in ns_to_do: @@ -1224,7 +1259,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): # process is converged once one colour has stoppped growing # once ns_to_do is empty converged = (not ns_to_do) and ( - (num_nodes[0] == last_num[0] != 0) or (num_nodes[1] == last_num[1] != 0) + (num_nodes[0] == last_num[0] != 0) or ( + num_nodes[1] == last_num[1] != 0) ) fully_converged = converged and last_converged if not fully_converged: @@ -1239,7 +1275,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): else: min_color = (conv_color + 1) % 2 # calculate the number of nodes for the other - num_nodes[(min_color + 1) % 2] = link_graph.num_nodes() - num_nodes[min_color] + num_nodes[(min_color + 1) % + 2] = link_graph.num_nodes() - num_nodes[min_color] # get the set of min nodes min_ns = set() for n, c in node_color.items(): @@ -1280,7 +1317,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): flipped_logicals = set() # otherwise, report only needed logicals that aren't given else: - flipped_logicals = flipped_logicals.difference(given_logicals) + flipped_logicals = flipped_logicals.difference( + given_logicals) flipped_logical_nodes = [] for flipped_logical in flipped_logicals: @@ -1291,7 +1329,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): ) flipped_logical_nodes.append(node) - if neutral and flipped_logical_nodes == []: + if neutral and not flipped_logical_nodes: break else: @@ -1308,15 +1346,13 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): return neutral, flipped_logical_nodes, num_errors - def is_cluster_neutral(self, atypical_nodes): + def is_cluster_neutral(self, atypical_nodes: dict): """ Determines whether or not the cluster is neutral, meaning that one or more errors could have caused the set of atypical nodes (syndrome changes) passed to the method. Args: - atypical_nodes (dictionary in the form of the return value of string2nodes) - ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are - ignored. + atypical_nodes: dictionary in the form of the return value of string2nodes """ neutral, logicals, _ = self.check_nodes(atypical_nodes) return neutral and not logicals @@ -1361,9 +1397,11 @@ def transpile(self, backend, echo=("X", "X"), echo_num=(2, 0)): dd_sequences.append([XGate()] * echo_num[j]) spacings.append(None) elif echo[j] == "XZX": - dd_sequences.append([XGate(), RZGate(np.pi), XGate()] * echo_num) + dd_sequences.append( + [XGate(), RZGate(np.pi), XGate()] * echo_num) d = 1.0 / (2 * echo_num - 1 + 1) - spacing = [d / 2] + ([0, d, d] * echo_num[j])[:-1] + [d / 2] + spacing = [d / 2] + \ + ([0, d, d] * echo_num[j])[:-1] + [d / 2] for _ in range(2): spacing[0] += 1 - sum(spacing) spacings.append(spacing) @@ -1399,7 +1437,9 @@ def transpile(self, backend, echo=("X", "X"), echo_num=(2, 0)): q = gate[1][0] t = gate[0].params[0] total_delay[0][q] += t - new_t = 16 * np.ceil((total_delay[0][q] - total_delay[1][q]) / 16) + new_t = 16 * \ + np.ceil( + (total_delay[0][q] - total_delay[1][q]) / 16) total_delay[1][q] += new_t gate[0].params[0] = new_t @@ -1515,7 +1555,8 @@ def get_error_coords(self, counts, decoding_graph, method="spitz"): for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: dts.append(dt) - time = [max(0, node0.time - 1 + (max(dts) + 1) / round_length)] + time = [ + max(0, node0.time - 1 + (max(dts) + 1) / round_length)] time.append(node0.time + min(dts) / round_length) # error during a round else: @@ -1533,8 +1574,10 @@ def get_error_coords(self, counts, decoding_graph, method="spitz"): dts.append(dt) # use to define fractional time if dts[0] < dts[1]: - time = [node_pair[1].time + (dts[0] + 1) / round_length] - time.append(node_pair[1].time + dts[1] / round_length) + time = [node_pair[1].time + + (dts[0] + 1) / round_length] + time.append( + node_pair[1].time + dts[1] / round_length) else: # impossible cases get no valid time time = [] @@ -1542,25 +1585,30 @@ def get_error_coords(self, counts, decoding_graph, method="spitz"): # measurement error assert node0.time != node1.time and node0.qubits == node1.qubits qubit = node0.properties["link qubit"] - time = [node0.time, node0.time + (round_length - 1) / round_length] + time = [node0.time, node0.time + + (round_length - 1) / round_length] time.sort() else: # detected only by one stabilizer - boundary_qubits = list(set(node0.qubits).intersection(z_logicals)) + boundary_qubits = list( + set(node0.qubits).intersection(z_logicals)) # for the case of boundary stabilizers if boundary_qubits: qubit = boundary_qubits[0] pair = [qubit, node0.properties["link qubit"]] for dt, pairs in enumerate(self.schedule): if pair in pairs or tuple(pair) in pairs: - time = [max(0, node0.time - 1 + (dt + 1) / round_length)] + time = [ + max(0, node0.time - 1 + (dt + 1) / round_length)] time.append(node0.time + dt / round_length) else: - qubit = tuple(node0.qubits + [node0.properties["link qubit"]]) - time = [node0.time, node0.time + (round_length - 1) / round_length] + qubit = tuple( + node0.qubits + [node0.properties["link qubit"]]) + time = [node0.time, node0.time + + (round_length - 1) / round_length] - if time != []: # only record if not nan + if time: # only record if not nan if (qubit, time[0], time[1]) not in error_coords: error_coords[qubit, time[0], time[1]] = {} error_coords[qubit, time[0], time[1]][n0, n1] = prob diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index dc800dcb..f6f6a8b1 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -478,7 +478,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): p = y else: p = x - num_errors = min(num_errors, min(p + 1, self.d - p)) + num_errors = min(num_errors, p + 1, self.d - p) flipped_logicals = {1 - int(p < (self.d - 1) / 2)} # if unneeded logical zs are given, cluster is not neutral diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index 774cf361..58b4dce0 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -19,13 +19,14 @@ """ import itertools import logging -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import rustworkx as rx + from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.exceptions import QiskitQECError -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphEdge, DecodingGraphNode class DecodingGraph: @@ -98,7 +99,8 @@ def _make_syndrome_graph(self): n1 = graph.nodes().index(target) qubits = [] if not (source.is_boundary and target.is_boundary): - qubits = list(set(source.qubits).intersection(target.qubits)) + qubits = list( + set(source.qubits).intersection(target.qubits)) if not qubits: continue if ( @@ -161,7 +163,8 @@ def get_error_probs( for string in counts: # list of i for which v_i=1 - error_nodes = set(self.code.string2nodes(string, logical=logical)) + error_nodes = set(self.code.string2nodes( + string, logical=logical)) for node0 in error_nodes: n0 = self.graph.nodes().index(node0) @@ -191,9 +194,11 @@ def get_error_probs( boundary.append(n0) else: if (1 - 2 * av_xor[n0, n1]) != 0: - x = (av_vv[n0, n1] - av_v[n0] * av_v[n1]) / (1 - 2 * av_xor[n0, n1]) + x = (av_vv[n0, n1] - av_v[n0] * av_v[n1]) / \ + (1 - 2 * av_xor[n0, n1]) if x < 0.25: - error_probs[n0, n1] = max(0, 0.5 - np.sqrt(0.25 - x)) + error_probs[n0, n1] = max( + 0, 0.5 - np.sqrt(0.25 - x)) else: error_probs[n0, n1] = np.nan else: @@ -222,7 +227,8 @@ def get_error_probs( for edge in self.graph.edge_list() } for string in counts: - error_nodes = set(self.code.string2nodes(string, logical=logical)) + error_nodes = set(self.code.string2nodes( + string, logical=logical)) for edge in self.graph.edge_list(): element = "" for j in range(2): @@ -293,12 +299,13 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): edge_data.weight = w self.graph.update_edge(edge[0], edge[1], edge_data) - def make_error_graph(self, data, all_logicals=True): + def make_error_graph(self, data: Union[str, List], all_logicals=True): """Returns error graph. Args: data: Either an ouput string of the code, or a list of nodes for the code. + all_logicals(bool): Whether to do all logicals Returns: The subgraph of graph which corresponds to the non-trivial @@ -322,7 +329,8 @@ def make_error_graph(self, data, all_logicals=True): def weight_fn(edge): return int(edge.weight) - distance_matrix = rx.graph_floyd_warshall_numpy(self.graph, weight_fn=weight_fn) + distance_matrix = rx.graph_floyd_warshall_numpy( + self.graph, weight_fn=weight_fn) for source_index in E.node_indexes(): for target_index in E.node_indexes(): @@ -333,9 +341,11 @@ def weight_fn(edge): nt = self.graph.nodes().index(target) distance = distance_matrix[ns][nt] if np.isfinite(distance): - qubits = list(set(source.qubits).intersection(target.qubits)) + qubits = list( + set(source.qubits).intersection(target.qubits)) distance = int(distance) - E.add_edge(source_index, target_index, DecodingGraphEdge(qubits, distance)) + E.add_edge(source_index, target_index, + DecodingGraphEdge(qubits, distance)) return E @@ -367,7 +377,8 @@ def __init__( self.round_schedule = round_schedule self.basis = basis - self.layer_types = self._layer_types(self.blocks, self.round_schedule, self.basis) + self.layer_types = self._layer_types( + self.blocks, self.round_schedule, self.basis) self._decoding_graph() @@ -442,7 +453,8 @@ def _decoding_graph(self): idx += 1 for index, supp in enumerate(boundary): # Add optional is_boundary property for pymatching - node = DecodingGraphNode(is_boundary=True, qubits=supp, index=index) + node = DecodingGraphNode( + is_boundary=True, qubits=supp, index=index) node.properties["highlighted"] = False graph.add_node(node) logging.debug("boundary %d t=%d %s", idx, time, supp) @@ -478,9 +490,11 @@ def _decoding_graph(self): edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 graph.add_edge( - idxmap[(time, tuple(op_g))], idxmap[(time, tuple(op_h))], edge + idxmap[(time, tuple(op_g))], idxmap[( + time, tuple(op_h))], edge ) - logging.debug("spacelike t=%d (%s, %s)", time, op_g, op_h) + logging.debug("spacelike t=%d (%s, %s)", + time, op_g, op_h) logging.debug( " qubits %s", [com[0]], @@ -497,8 +511,10 @@ def _decoding_graph(self): edge = DecodingGraphEdge(qubits=[], weight=0) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 - graph.add_edge(idxmap[(time, tuple(bound_g))], idxmap[(time, tuple(bound_h))], edge) - logging.debug("spacelike boundary t=%d (%s, %s)", time, bound_g, bound_h) + graph.add_edge(idxmap[(time, tuple(bound_g))], + idxmap[(time, tuple(bound_h))], edge) + logging.debug("spacelike boundary t=%d (%s, %s)", + time, bound_g, bound_h) # Add (space)time-like edges from t to t-1 # By construction, the qubit sets of pairs of vertices at graph and T @@ -543,9 +559,11 @@ def _decoding_graph(self): idxmap[(time, tuple(op_g))], edge, ) - logging.debug("timelike t=%d (%s, %s)", time, op_g, op_h) + logging.debug( + "timelike t=%d (%s, %s)", time, op_g, op_h) else: # Case (b) - edge = DecodingGraphEdge(qubits=[com[0]], weight=1) + edge = DecodingGraphEdge( + qubits=[com[0]], weight=1) edge.properties["highlighted"] = False edge.properties["measurement_error"] = 1 graph.add_edge( @@ -553,7 +571,8 @@ def _decoding_graph(self): idxmap[(time, tuple(op_g))], edge, ) - logging.debug("spacetime hook t=%d (%s, %s)", time, op_g, op_h) + logging.debug( + "spacetime hook t=%d (%s, %s)", time, op_g, op_h) logging.debug(" qubits %s", [com[0]]) # Add a single time-like edge between boundary vertices at # time t-1 and t @@ -561,7 +580,8 @@ def _decoding_graph(self): edge.properties["highlighted"] = False edge.properties["measurement_error"] = 0 graph.add_edge( - idxmap[(time - 1, tuple(boundary[0]))], idxmap[(time, tuple(boundary[0]))], edge + idxmap[(time - 1, tuple(boundary[0])) + ], idxmap[(time, tuple(boundary[0]))], edge ) logging.debug("boundarylink t=%d", time) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 60a788b7..6c1ed19c 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -16,15 +16,16 @@ """Hard decision renormalization group decoders.""" +from abc import ABC from copy import copy, deepcopy from dataclasses import dataclass from typing import Dict, List, Set, Tuple -from abc import ABC -from rustworkx import connected_components, distance_matrix, PyGraph + +from rustworkx import PyGraph, connected_components, distance_matrix from qiskit_qec.circuits.repetition_code import ArcCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.utils import DecodingGraphEdge, DecodingGraphNode class ClusteringDecoder(ABC): @@ -87,7 +88,8 @@ def get_corrections(self, string, clusters): cluster_logicals[c] = z_logicals # get the net effect on each logical - net_z_logicals = {z_logical[0]: 0 for z_logical in self.measured_logicals} + net_z_logicals = { + z_logical[0]: 0 for z_logical in self.measured_logicals} for c, z_logicals in cluster_logicals.items(): for z_logical in self.measured_logicals: if z_logical[0] in z_logicals: @@ -99,7 +101,8 @@ def get_corrections(self, string, clusters): string = string.split(" ")[0] for z_logical in self.measured_logicals: raw_logical = int(string[-1 - self.code_index[z_logical[0]]]) - corrected_logical = (raw_logical + net_z_logicals[z_logical[0]]) % 2 + corrected_logical = ( + raw_logical + net_z_logicals[z_logical[0]]) % 2 corrected_z_logicals.append(corrected_logical) return corrected_z_logicals @@ -168,7 +171,8 @@ def _cluster(self, ns, dist_max): def _get_boundary_nodes(self): boundary_nodes = [] for element, z_logical in enumerate(self.measured_logicals): - node = DecodingGraphNode(is_boundary=True, qubits=z_logical, index=element) + node = DecodingGraphNode( + is_boundary=True, qubits=z_logical, index=element) if isinstance(self.code, ArcCircuit): node.properties["link qubit"] = None boundary_nodes.append(node) @@ -188,7 +192,8 @@ def cluster(self, nodes): # get indices for nodes and boundary nodes dg = self.decoding_graph.graph ns = set(dg.nodes().index(node) for node in nodes) - bns = set(dg.nodes().index(node) for node in self._get_boundary_nodes()) + bns = set(dg.nodes().index(node) + for node in self._get_boundary_nodes()) dist_max = 0 final_clusters = {} @@ -327,14 +332,16 @@ def process(self, string: str): if self.use_peeling: self.graph = deepcopy(self.decoding_graph.graph) - highlighted_nodes = self.code.string2nodes(string, all_logicals=True) + highlighted_nodes = self.code.string2nodes( + string, all_logicals=True) # call cluster to do the clustering, but actually use the peeling form self.cluster(highlighted_nodes) clusters = self._clusters4peeling # determine the net logical z - net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals} + net_z_logicals = { + tuple(z_logical): 0 for z_logical in self.measured_logicals} for cluster_nodes, _ in clusters: erasure = self.graph.subgraph(cluster_nodes) flipped_qubits = self.peeling(erasure) @@ -350,7 +357,8 @@ def process(self, string: str): raw_logicals = self.code.string2raw_logicals(string) for j, z_logical in enumerate(self.measured_logicals): raw_logical = int(raw_logicals[j]) - corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2 + corrected_logical = ( + raw_logical + net_z_logicals[tuple(z_logical)]) % 2 corrected_z_logicals.append(corrected_logical) return corrected_z_logicals else: @@ -359,16 +367,13 @@ def process(self, string: str): clusters = self.cluster(nodes) return self.get_corrections(string, clusters) - def cluster(self, nodes): + def cluster(self, nodes: List): """ Create clusters using the union-find algorithm. Args: nodes (List): List of non-typical nodes in the syndrome graph, of the type produced by `string2nodes`. - standard_form (Bool): Whether to use the standard form of - the clusters for clustering decoders, or the form used internally - by the class. Returns: clusters (dict): Dictionary with the indices of @@ -396,7 +401,8 @@ def cluster(self, nodes): clusters = {} for c, cluster in self.clusters.items(): # determine which nodes exactly are in the neutral cluster - neutral_nodes = list(cluster.atypical_nodes | cluster.boundary_nodes) + neutral_nodes = list(cluster.atypical_nodes | + cluster.boundary_nodes) # put them in the required dict for n in neutral_nodes: clusters[n] = c @@ -407,7 +413,8 @@ def cluster(self, nodes): if not cluster.atypical_nodes: continue self._clusters4peeling.append( - (list(cluster.nodes), list(cluster.atypical_nodes | cluster.boundary_nodes)) + (list(cluster.nodes), list( + cluster.atypical_nodes | cluster.boundary_nodes)) ) return clusters @@ -426,7 +433,8 @@ def find(self, u: int) -> int: if self.graph[u].properties["root"] == u: return self.graph[u].properties["root"] - self.graph[u].properties["root"] = self.find(self.graph[u].properties["root"]) + self.graph[u].properties["root"] = self.find( + self.graph[u].properties["root"]) return self.graph[u].properties["root"] def _create_new_cluster(self, node_index): @@ -435,11 +443,13 @@ def _create_new_cluster(self, node_index): self.odd_cluster_roots.insert(0, node_index) boundary_edges = [] for edge_index, neighbour, data in self.neighbouring_edges(node_index): - boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, data)) + boundary_edges.append(BoundaryEdge( + edge_index, node_index, neighbour, data)) self.clusters[node_index] = UnionFindDecoderCluster( boundary=boundary_edges, fully_grown_edges=set(), - atypical_nodes=set([node_index]) if not node.is_boundary else set([]), + atypical_nodes=set( + [node_index]) if not node.is_boundary else set([]), boundary_nodes=set([node_index]) if node.is_boundary else set([]), nodes=set([node_index]), size=1, @@ -606,7 +616,8 @@ def peeling(self, erasure: PyGraph) -> List[int]: edges = set() for edge in tree.edges[::-1]: endpoints = erasure.get_edge_endpoints_by_index(edge) - pendant_vertex = endpoints[0] if not tree.vertices[endpoints[0]] else endpoints[1] + pendant_vertex = endpoints[0] if not tree.vertices[endpoints[0] + ] else endpoints[1] tree_vertex = endpoints[0] if pendant_vertex == endpoints[1] else endpoints[1] tree.vertices[tree_vertex].remove(edge) if erasure[pendant_vertex].properties["syndrome"]: @@ -685,7 +696,8 @@ def process(self, string: str): edge.properties["fully_grown"] = False string = "".join([str(c) for c in string[::-1]]) - output = [int(bit) for bit in list(string.split(" ", maxsplit=self.code.d)[0])][::-1] + output = [int(bit) for bit in list( + string.split(" ", maxsplit=self.code.d)[0])][::-1] highlighted_nodes = self.code.string2nodes(string, all_logicals=True) if not highlighted_nodes: return output # There's nothing for us to do here @@ -696,7 +708,8 @@ def process(self, string: str): flattened_highlighted_nodes: List[DecodingGraphNode] = [] for highlighted_node in highlighted_nodes: highlighted_node.time = 0 - flattened_highlighted_nodes.append(self.graph.nodes().index(highlighted_node)) + flattened_highlighted_nodes.append( + self.graph.nodes().index(highlighted_node)) for cluster_nodes, cluster_atypical_nodes in clusters: if not cluster_nodes: @@ -725,7 +738,8 @@ def cluster(self, nodes): self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots = [] - times: List[List[DecodingGraphNode]] = [[] for _ in range(self.code.T + 1)] + times: List[List[DecodingGraphNode]] = [[] + for _ in range(self.code.T + 1)] boundaries = [] for node in deepcopy(nodes): if node.is_boundary: @@ -757,7 +771,8 @@ def cluster(self, nodes): clusters = {} for c, cluster in enumerate(neutral_clusters): # determine which nodes exactly are in the neutral cluster - neutral_nodes = list(cluster.atypical_nodes | cluster.boundary_nodes) + neutral_nodes = list(cluster.atypical_nodes | + cluster.boundary_nodes) # put them in the required dict for n in neutral_nodes: clusters[n] = c @@ -765,7 +780,8 @@ def cluster(self, nodes): neutral_cluster_nodes: List[List[int]] = [] for cluster in neutral_clusters: neutral_cluster_nodes.append( - (list(cluster.nodes), list(cluster.atypical_nodes | cluster.boundary_nodes)) + (list(cluster.nodes), list( + cluster.atypical_nodes | cluster.boundary_nodes)) ) self._clusters4peeling = neutral_cluster_nodes diff --git a/src/qiskit_qec/linear/symplectic.py b/src/qiskit_qec/linear/symplectic.py index 80fe38fe..7222a8c1 100644 --- a/src/qiskit_qec/linear/symplectic.py +++ b/src/qiskit_qec/linear/symplectic.py @@ -13,12 +13,11 @@ """Symplectic functions.""" from collections import deque -from typing import List, Any, Tuple -from typing import Union, Optional +from typing import Any, List, Optional, Tuple, Union import numpy as np - from qiskit import QiskitError + from qiskit_qec.linear import matrix as mt @@ -93,7 +92,8 @@ def symplectic_product(mat1: np.ndarray, mat2: np.ndarray) -> Union[int, np.ndar mat2_np_array = np.array(mat2, dtype=np.int8) if not is_symplectic_form(mat1) or not is_symplectic_form(mat2): - raise QiskitError("Input matrices/vectors must be GF(2) symplectic matrices/vectors") + raise QiskitError( + "Input matrices/vectors must be GF(2) symplectic matrices/vectors") if not mat1_np_array.ndim == mat2_np_array.ndim: raise QiskitError( @@ -325,14 +325,16 @@ def make_commute_hyper( _make_commute_hyper """ if not (is_symplectic_form(a) and is_symplectic_form(x) and is_symplectic_form(z)): - raise QiskitError("Input matrices/vectors must be GF(2) symplectic matrices/vectors") + raise QiskitError( + "Input matrices/vectors must be GF(2) symplectic matrices/vectors") def make_list(srange): if srange is not None: try: srange = list(srange) except TypeError as terror: - raise QiskitError(f"Input range {srange} is not iterable") from terror + raise QiskitError( + f"Input range {srange} is not iterable") from terror return srange @@ -349,7 +351,8 @@ def make_list(srange): z = np.atleast_2d(np.array(z)) if not a.shape[1] == x.shape[1] == z.shape[1]: - raise QiskitError("Input matrices/vectors must have the same number of columns/length") + raise QiskitError( + "Input matrices/vectors must have the same number of columns/length") return _make_commute_hyper(a, x, z, arange, xrange, zrange, squeeze) @@ -618,7 +621,8 @@ def build_hyper_partner(matrix, index: int) -> np.ndarray: # matrix -> all associated operators must commute if not all_commute(matrix): - raise QiskitError("Input matrix must represent a set of commuting operators") + raise QiskitError( + "Input matrix must represent a set of commuting operators") rank_ = mt.rank(matrix) if rank_ != matrix.shape[0]: @@ -628,7 +632,8 @@ def build_hyper_partner(matrix, index: int) -> np.ndarray: ) if index not in range(matrix.shape[1] >> 1): - raise QiskitError(f"Input index out or range: {index}>={matrix.shape[1]>>1}") + raise QiskitError( + f"Input index out or range: {index}>={matrix.shape[1]>>1}") return _build_hyper_partner(matrix, index) @@ -753,7 +758,8 @@ def symplectic_gram_schmidt( x = np.atleast_2d(x) x = list(x) if not is_symplectic_vector_form(x[0]): - raise QiskitError("Input hyperbolic array x is not a GF(2) sympletic matrix") + raise QiskitError( + "Input hyperbolic array x is not a GF(2) sympletic matrix") if z is None: z = [] @@ -761,7 +767,8 @@ def symplectic_gram_schmidt( z = np.atleast_2d(z) z = list(z) if not is_symplectic_vector_form(z[0]): - raise QiskitError("Input hyperbolic array z is not a GF(2) sympletic matrix") + raise QiskitError( + "Input hyperbolic array z is not a GF(2) sympletic matrix") if not len(x) == len(z): raise QiskitError("Input hyperbolic arrays have different dimensions") @@ -771,7 +778,8 @@ def symplectic_gram_schmidt( if x != []: if not is_hyper_form(x, z): - raise QiskitError("Input hyperbolic matrices do not represent a hyperbolic basis") + raise QiskitError( + "Input hyperbolic matrices do not represent a hyperbolic basis") return _symplectic_gram_schmidt(a, x, z) @@ -834,7 +842,7 @@ def _symplectic_gram_schmidt( # Revove elem_p from a_view temp_view = a_view[:-1] - temp_view[index:] = a_view[index + 1 :] + temp_view[index:] = a_view[index + 1:] a_view = temp_view a_view = make_commute_hyper(a_view, elem, elem_p) @@ -1496,7 +1504,8 @@ def _make_hyperbolic( while center_size > 0: hop = _build_hyper_partner(center_[:center_size], center_size - 1) # TODO: Change the use of make_commute_hyper to _make_commute_hyper - hop = make_commute_hyper(hop, x, z, xrange=range(hyper_size), zrange=range(hyper_size)) + hop = make_commute_hyper(hop, x, z, xrange=range( + hyper_size), zrange=range(hyper_size)) # hop = _make_element_commute_with_hyper_pairs( # hop, x, z, range(hyper_size), range(hyper_size) # ) @@ -1657,7 +1666,8 @@ def hyperbolic_basis_for_pauli_group( else: if matrix is None: if n is None: - raise QiskitError("If matrix, x and z are None then n must be provided") + raise QiskitError( + "If matrix, x and z are None then n must be provided") zero = np.zeros(shape=(n, n), dtype=np.bool_) x = mt.augment_mat(zero, "left") z = mt.augment_mat(zero, "right") @@ -1803,7 +1813,8 @@ def remove_hyper_elements_from_hyper_form( if center_ is not None: center_ = np.atleast_2d(np.array(center_)) if not is_symplectic_matrix_form(center_): - raise QiskitError("Input center is not a GF(2) symplectiv matrix/vector") + raise QiskitError( + "Input center is not a GF(2) symplectiv matrix/vector") if not x.shape[1] == center_.shape[1]: raise QiskitError( "x and z must have the same size in the second \ @@ -1888,7 +1899,8 @@ def min_generating( Args: matrix (Optional[np.ndarray]): Input GF(2) symplectic matrix - x, z (Optional[np.ndarray]): Input hyperbolic set - pair of GF(2) symplectic matrices + x (Optional[np.ndarray]): Input hyperbolic set - pair of GF(2) symplectic matrices + z (Optional[np.ndarray]): Input hyperbolic set - pair of GF(2) symplectic matrices Raises: QiskitError: An input matrix is required @@ -1969,7 +1981,8 @@ def _min_generating_matrix(matrix: np.ndarray) -> np.ndarray: return matrix posns = np.flatnonzero(heads) - ext_matrix = np.zeros(shape=(posns.shape[0], matrix.shape[1]), dtype=np.bool_) + ext_matrix = np.zeros( + shape=(posns.shape[0], matrix.shape[1]), dtype=np.bool_) for k, index in enumerate(posns): ext_matrix[k] = matrix[index] return ext_matrix @@ -1991,8 +2004,9 @@ def _min_generating_matrix_xz(matrix, x, z) -> np.ndarray: heads, _, _, rank_ = mt.rref_complete(tmp_matrix) if rank_ < tmp_matrix.shape[0]: # Since (x,z) has already been reduced any removal will appear in the matrix part - posns = np.flatnonzero(heads[2 * x.shape[0] :]) - ext_matrix = np.zeros(shape=(posns.shape[0], matrix.shape[1]), dtype=np.bool_) + posns = np.flatnonzero(heads[2 * x.shape[0]:]) + ext_matrix = np.zeros( + shape=(posns.shape[0], matrix.shape[1]), dtype=np.bool_) for k, index in enumerate(posns): ext_matrix[k] = matrix[index] matrix = ext_matrix @@ -2075,7 +2089,8 @@ def normalizer( raise QiskitError("x and z must have the same shape") if not matrix.shape[1] == x.shape[1]: - raise QiskitError("All inputs must have the same number of columns/length") + raise QiskitError( + "All inputs must have the same number of columns/length") if not is_center(matrix, np.vstack((matrix, x, z))): raise QiskitError( @@ -2106,7 +2121,8 @@ def _normalizer_abelian_group( matrix_ext = _basis_for_pauli_group(matrix) center_, x, z = _symplectic_gram_schmidt(matrix_ext, [], []) - center_, x, z = _remove_hyper_elements_from_hyper_form(center_, x, z, list(range(dist_center))) + center_, x, z = _remove_hyper_elements_from_hyper_form( + center_, x, z, list(range(dist_center))) if center_.shape[0] > 1: center_ = center_[np.where(center_.any(axis=1))[0]] @@ -2143,7 +2159,8 @@ def _normalizer_group_preserve( x, z = _make_hyperbolic(center_, x, z) matrix_ext = _basis_for_pauli_group(np.vstack((x, z))) - matrix_ext = make_commute_hyper(matrix_ext, x, z, range(x.shape[0] << 1, matrix_ext.shape[0])) + matrix_ext = make_commute_hyper( + matrix_ext, x, z, range(x.shape[0] << 1, matrix_ext.shape[0])) # matrix_ext = _make_elements_commute_with_hyper_pairs( # matrix_ext, # range(x.shape[0] << 1, matrix_ext.shape[0]), @@ -2152,9 +2169,9 @@ def _normalizer_group_preserve( # z, # range(z.shape[0]), # ) - matrix = matrix_ext[x.shape[0] << 1 :] + matrix = matrix_ext[x.shape[0] << 1:] lx = [item.copy() for item in matrix_ext[: x.shape[0]]] - lz = [item.copy() for item in matrix_ext[x.shape[0] : x.shape[0] << 1]] + lz = [item.copy() for item in matrix_ext[x.shape[0]: x.shape[0] << 1]] center_, x, z = _symplectic_gram_schmidt(matrix, lx, lz) indices = list(range(gauge_degree, gauge_degree + center_size)) return _remove_hyper_elements_from_hyper_form(center_, x, z, indices) diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index c650e72c..85699053 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -57,11 +57,13 @@ def __getitem__(self, key): return self.properties[key] else: raise QiskitQECError( - "'" + str(key) + "'" + " is not an an attribute or property of the node." + "'" + str(key) + "'" + + " is not an an attribute or property of the node." ) def get(self, key, _): """A dummy docstring.""" + # pylint: disable=unnecessary-dunder-call return self.__getitem__(key) def __setitem__(self, key, value): @@ -117,11 +119,13 @@ def __getitem__(self, key): return self.properties[key] else: raise QiskitQECError( - "'" + str(key) + "'" + " is not an an attribute or property of the edge." + "'" + str(key) + "'" + + " is not an an attribute or property of the edge." ) def get(self, key, _): """A dummy docstring.""" + # pylint: disable=unnecessary-dunder-call return self.__getitem__(key) def __setitem__(self, key, value): diff --git a/src/qiskit_qec/utils/stim_tools.py b/src/qiskit_qec/utils/stim_tools.py index b5ca721c..17015ab4 100644 --- a/src/qiskit_qec/utils/stim_tools.py +++ b/src/qiskit_qec/utils/stim_tools.py @@ -144,7 +144,7 @@ def get_counts_via_stim( """Returns a qiskit compatible dictionary of measurement outcomes Args: - circuit: Qiskit circuit compatible with `get_stim_circuits` or list thereof. + circuits: Qiskit circuit compatible with `get_stim_circuits` or list thereof. shots: Number of samples to be generated. noise_model: Pauli noise model for any additional noise to be applied. diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 763fe70d..a84eb9c0 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -16,20 +16,23 @@ """Run codes and decoders.""" -import unittest import itertools +import unittest from random import choices from qiskit import Aer, QuantumCircuit, execute from qiskit.providers.fake_provider import FakeJakarta from qiskit_aer.noise import NoiseModel from qiskit_aer.noise.errors import depolarizing_error -from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit as RepetitionCode + +from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.circuits.repetition_code import ArcCircuit +from qiskit_qec.circuits.repetition_code import \ + RepetitionCodeCircuit as RepetitionCode from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.decoders.hdrg_decoders import (BravyiHaahDecoder, + UnionFindDecoder) from qiskit_qec.utils import DecodingGraphNode -from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder def get_syndrome(code, noise_model, shots=1024): @@ -55,7 +58,8 @@ def get_noise(p_meas, p_gate): error_gate2 = error_gate1.tensor(error_gate1) noise_model = NoiseModel() - noise_model.add_all_qubit_readout_error([[1 - p_meas, p_meas], [p_meas, 1 - p_meas]]) + noise_model.add_all_qubit_readout_error( + [[1 - p_meas, p_meas], [p_meas, 1 - p_meas]]) noise_model.add_all_qubit_quantum_error(error_gate1, ["h"]) noise_model.add_all_qubit_quantum_error(error_gate2, ["cx"]) @@ -91,7 +95,7 @@ def single_error_test( temp_qc.name = str((j, qubit, error)) temp_qc.data = qc.data[0:j] getattr(temp_qc, error)(qubit) - temp_qc.data += qc.data[j : depth + 1] + temp_qc.data += qc.data[j: depth + 1] circuit_name[(j, qubit, error)] = temp_qc.name error_circuit[temp_qc.name] = temp_qc @@ -125,7 +129,8 @@ def test_string2nodes_1(self): s0 = "0 0 01 00 01" s1 = "1 1 01 00 01" self.assertTrue( - code.string2nodes(s0, logical="0") == code.string2nodes(s1, logical="1"), + code.string2nodes(s0, logical="0") == code.string2nodes( + s1, logical="1"), "Error: Incorrect nodes from results string", ) @@ -197,8 +202,10 @@ def test_weight(self): + "'." ) p = dec.get_error_probs(test_results, method=method) - n0 = dec.graph.nodes().index(DecodingGraphNode(time=0, qubits=[0, 1], index=0)) - n1 = dec.graph.nodes().index(DecodingGraphNode(time=0, qubits=[1, 2], index=1)) + n0 = dec.graph.nodes().index( + DecodingGraphNode(time=0, qubits=[0, 1], index=0)) + n1 = dec.graph.nodes().index( + DecodingGraphNode(time=0, qubits=[1, 2], index=1)) # edges in graph aren't directed and could be in any order if (n0, n1) in p: self.assertTrue(round(p[n0, n1], 2) == 0.33, error) @@ -224,7 +231,8 @@ def single_error_test( edges = link_graph.incident_edges(n) incident_links[node] = set() for edge in edges: - incident_links[node].add(link_graph.edges()[edge]["link qubit"]) + incident_links[node].add( + link_graph.edges()[edge]["link qubit"]) if node in code.z_logicals: incident_links[node].add(None) @@ -246,7 +254,8 @@ def single_error_test( minimal = minimal and (max(ts) - min(ts)) <= 1 # check that it doesn't extend beyond the neigbourhood of a code qubit flat_nodes = code.flatten_nodes(nodes) - link_qubits = set(node.properties["link qubit"] for node in flat_nodes) + link_qubits = set( + node.properties["link qubit"] for node in flat_nodes) minimal = minimal and link_qubits in incident_links.values() self.assertTrue( minimal, @@ -315,23 +324,28 @@ def test_202s(self): self.assertTrue( running_202 == run_202, "Error: [[2,0,2]] codes not present when required." * run_202 - + "Error: [[2,0,2]] codes present when not required." * (not run_202), + + "Error: [[2,0,2]] codes present when not required." * + (not run_202), ) # second, do they yield non-trivial outputs yet trivial nodes - code = ArcCircuit(links, T=T, run_202=True, logical="1", rounds_per_202=5) + code = ArcCircuit(links, T=T, run_202=True, + logical="1", rounds_per_202=5) backend = Aer.get_backend("aer_simulator") counts = backend.run(code.circuit[code.basis]).result().get_counts() - self.assertTrue(len(counts) > 1, "No randomness in the results for [[2,0,2]] circuits.") + self.assertTrue(len(counts) > 1, + "No randomness in the results for [[2,0,2]] circuits.") nodeless = True for string in counts: - nodeless = nodeless and code.string2nodes(string) == [] - self.assertTrue(nodeless, "Non-trivial nodes found for noiseless [[2,0,2]] circuits.") + nodeless = nodeless and not code.string2nodes(string) + self.assertTrue( + nodeless, "Non-trivial nodes found for noiseless [[2,0,2]] circuits.") def test_single_error_202s(self): """Test a range of single errors for a code with [[2,0,2]] codes.""" links = [(0, 1, 2), (2, 3, 4), (4, 5, 0), (2, 7, 6)] for T in [21, 25]: - code = ArcCircuit(links, T, run_202=True, barriers=True, logical="1", rounds_per_202=5) + code = ArcCircuit(links, T, run_202=True, + barriers=True, logical="1", rounds_per_202=5) assert code.run_202 # insert errors on a selection of qubits during a selection of rounds qc = code.circuit[code.base] @@ -348,11 +362,14 @@ def test_single_error_202s(self): barrier_num += 1 if barrier_num == 2 * t + 1: if q % 2 == 0: - error_qc.z(code.code_qubit[code.code_index[q]]) + error_qc.z( + code.code_qubit[code.code_index[q]]) else: - error_qc.x(code.link_qubit[code.link_index[q]]) + error_qc.x( + code.link_qubit[code.link_index[q]]) error_qc.append(gate) - counts = Aer.get_backend("qasm_simulator").run(error_qc).result().get_counts() + counts = Aer.get_backend("qasm_simulator").run( + error_qc).result().get_counts() for string in counts: # look at only bulk non-conjugate nodes nodes = [ @@ -397,7 +414,8 @@ def test_feedforward(self): counts = result.get_counts(j) for string in counts: # final result should be same as initial - correct_final = code.logical + str((int(code.logical) + 1) % 2) + code.logical * 2 + correct_final = code.logical + \ + str((int(code.logical) + 1) % 2) + code.logical * 2 correct = correct and string[0:4] == correct_final self.assertTrue(correct, "Result string not as required") @@ -420,7 +438,8 @@ def test_bases(self): else: for op in ["s", "sdg"]: rightops = rightops and op not in ops - self.assertTrue(rightops, "Error: Required rotations for basis changes not present.") + self.assertTrue( + rightops, "Error: Required rotations for basis changes not present.") def test_anisotropy(self): """Test that code qubits have neighbors with the opposite color.""" @@ -430,7 +449,8 @@ def test_anisotropy(self): color = code.color for j in range(1, link_num - 1): self.assertTrue( - color[2 * j] != color[2 * (j - 1)] or color[2 * j] != color[2 * (j + 1)], + color[2 * j] != color[2 * + (j - 1)] or color[2 * j] != color[2 * (j + 1)], "Error: Code qubit does not have neighbor of oppposite color.", ) @@ -439,16 +459,20 @@ def test_transpilation(self): backend = FakeJakarta() links = [(0, 1, 3), (3, 5, 6)] schedule = [[(0, 1), (3, 5)], [(3, 1), (6, 5)]] - code = ArcCircuit(links, schedule=schedule, T=2, delay=1000, logical="0") + code = ArcCircuit(links, schedule=schedule, + T=2, delay=1000, logical="0") circuit = code.transpile(backend) - self.assertTrue(code.schedule == schedule, "Error: Given schedule not used.") + self.assertTrue(code.schedule == schedule, + "Error: Given schedule not used.") circuit = code.transpile(backend, echo_num=(0, 2)) self.assertTrue( - circuit[code.base].count_ops()["x"] == 2, "Error: Wrong echo sequence for link qubits." + circuit[code.base].count_ops( + )["x"] == 2, "Error: Wrong echo sequence for link qubits." ) circuit = code.transpile(backend, echo_num=(2, 0)) self.assertTrue( - circuit[code.base].count_ops()["x"] == 8, "Error: Wrong echo sequence for code qubits." + circuit[code.base].count_ops( + )["x"] == 8, "Error: Wrong echo sequence for code qubits." ) self.assertTrue( circuit[code.base].count_ops()["cx"] == 8, @@ -526,7 +550,8 @@ def clustering_decoder_test( for row in [0, 1]: for j in range(half_d - 1): delta = row * (2 * half_d - 1) - links_ladder.append((delta + 2 * j, delta + 2 * j + 1, delta + 2 * (j + 1))) + links_ladder.append( + (delta + 2 * j, delta + 2 * j + 1, delta + 2 * (j + 1))) q = links_ladder[-1][2] + 1 for j in range(half_d): delta = 2 * half_d - 1 @@ -537,15 +562,17 @@ def clustering_decoder_test( for c, code in enumerate(codes): decoding_graph = DecodingGraph(code) if c == 3 and Decoder is UnionFindDecoder: - decoder = Decoder(code, decoding_graph=decoding_graph, use_peeling=False) + decoder = Decoder( + code, decoding_graph=decoding_graph, use_peeling=False) else: decoder = Decoder(code, decoding_graph=decoding_graph) - errors = {z_logical[0]: 0 for z_logical in decoder.measured_logicals} + errors = {z_logical[0] : 0 for z_logical in decoder.measured_logicals} min_error_num = code.d min_error_string = "" for _ in range(N): # generate random string - string = "".join([choices(["1", "0"], [1 - p, p])[0] for _ in range(d)]) + string = "".join([choices(["1", "0"], [1 - p, p])[0] + for _ in range(d)]) for _ in range(code.T): string = string + " " + "0" * (d - 1) # get and check corrected_z_logicals diff --git a/test/union_find/test_clayg.py b/test/union_find/test_clayg.py index 6147df34..765456e2 100644 --- a/test/union_find/test_clayg.py +++ b/test/union_find/test_clayg.py @@ -15,8 +15,9 @@ import math import random import unittest -from qiskit_qec.decoders import ClAYGDecoder + from qiskit_qec.circuits import RepetitionCodeCircuit +from qiskit_qec.decoders import ClAYGDecoder from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel @@ -124,7 +125,8 @@ def test_error_rates(self): testcases = [] testcases = [ - "".join([random.choices(["0", "1"], [1 - p, p])[0] for _ in range(d)]) + "".join([random.choices(["0", "1"], [1 - p, p])[0] + for _ in range(d)]) for _ in range(samples) ] codes = self.construct_codes(d) @@ -144,10 +146,13 @@ def test_error_rates(self): string += testcases[sample] # get and check corrected_z_logicals outcome = decoder.process(string) - logical_outcome = sum([outcome[int(z_logical / 2)] for z_logical in z_logicals]) % 2 + # pylint: disable=consider-using-generator + logical_outcome = sum([outcome[int(z_logical / 2)] + for z_logical in z_logicals]) % 2 if not logical_outcome == 0: logical_errors += 1 - min_flips_for_logical = min(min_flips_for_logical, string.count("1")) + min_flips_for_logical = min( + min_flips_for_logical, string.count("1")) # check that error rates are at least d/3 diff --git a/test/utils/test_visualization.py b/test/utils/test_visualization.py index 7635728d..7f83d355 100644 --- a/test/utils/test_visualization.py +++ b/test/utils/test_visualization.py @@ -1,5 +1,17 @@ -# """Test pauli rep.""" +# -*- coding: utf-8 -*- +# This code is part of Qiskit. +# +# (C) Copyright IBM 2019-2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Visualization Module""" from unittest import TestCase from qiskit_qec.utils.visualizations import QiskitGameEngine, Screen @@ -42,10 +54,12 @@ def next_frame(_): "Default color not correct for pixel in game engine.", ) self.assertTrue( - engine.screen.pixel[engine.size - 1, 0].button.button_style == "danger", + engine.screen.pixel[engine.size - 1, + 0].button.button_style == "danger", "Pixel in game engine did not turn red when required.", ) self.assertTrue( - engine.screen.pixel[0, engine.size - 1].button.description == "hello", + engine.screen.pixel[0, engine.size - + 1].button.description == "hello", "Pixel in game engine not displaying correct text", )