diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index f4037414..08cdc360 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -303,7 +303,9 @@ def get_error_probs( else: return error_probs - def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): + def weight_syndrome_graph( + self, counts: dict = None, method: str = METHOD_SPITZ, error_probs: dict = None + ): """Generate weighted syndrome graph from result counts. Args: @@ -311,6 +313,8 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): the weights. method (string): Method to used for calculation. Supported methods are 'spitz' (default) and 'naive'. + error_probs (dict): probability that the syndrome contains the node pair + of a given edge. Overridden by counts if both are given. Additional information: Uses `counts` to estimate the probability of the errors that @@ -318,7 +322,13 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): replaced with the corresponding -log(p/(1-p). """ - error_probs = self.get_error_probs(counts, method=method) + if counts: + error_probs = self.get_error_probs(counts, method=method) + elif not error_probs: + raise NotImplementedError( + "No information provided to reweight the graph." + + "Specify either counts or error_probs." + ) boundary_nodes = [] for n, node in enumerate(self.graph.nodes()):