diff --git a/qiskit_addon_obp/utils/visualization.py b/qiskit_addon_obp/utils/visualization.py index 745891f..040b4ee 100644 --- a/qiskit_addon_obp/utils/visualization.py +++ b/qiskit_addon_obp/utils/visualization.py @@ -36,7 +36,7 @@ from .metadata import OBPMetadata -def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None: +def plot_accumulated_error(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None: """Plot the accumulated error. This method populates the provided figure axes with a line-plot of the @@ -72,6 +72,7 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ if not np.isinf(metadata.truncation_error_budget.max_error_total): axes.axhline( @@ -93,10 +94,12 @@ def plot_accumulated_error(metadata: OBPMetadata, axes: Axes) -> None: ) axes.set_xlabel("backpropagated slice number") axes.set_ylabel("accumulated error") - axes.legend() + _set_legend(axes, show_legend) -def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None: +def plot_left_over_error_budget( + metadata: OBPMetadata, axes: Axes, show_legend: bool = True +) -> None: """Plot the left-over error budget. This method populates the provided figure axes with a line-plot of the @@ -127,6 +130,7 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ for obs_idx in range(len(metadata.backpropagation_history[0].slice_errors)): axes.plot( @@ -139,10 +143,10 @@ def plot_left_over_error_budget(metadata: OBPMetadata, axes: Axes) -> None: ) axes.set_xlabel("backpropagated slice number") axes.set_ylabel("left-over error budget") - axes.legend() + _set_legend(axes, show_legend) -def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None: +def plot_slice_errors(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None: """Plot the slice errors. This method populates the provided figure axes with a bar-plot of the truncation error incurred @@ -176,6 +180,7 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ num_observables = len(metadata.backpropagation_history[0].slice_errors) width = 0.8 / num_observables @@ -193,9 +198,10 @@ def plot_slice_errors(metadata: OBPMetadata, axes: Axes) -> None: axes.set_xlabel("backpropagated slice number") axes.set_ylabel("incurred slice error") axes.legend() + _set_legend(axes, show_legend) -def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None: +def plot_num_paulis(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None: """Plot the number of Pauli terms. This method populates the provided figure axes with a line-plot of the number of Pauli terms at @@ -229,6 +235,7 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ for obs_idx in range(len(metadata.backpropagation_history[0].slice_errors)): axes.plot( @@ -238,10 +245,10 @@ def plot_num_paulis(metadata: OBPMetadata, axes: Axes) -> None: ) axes.set_xlabel("backpropagated slice number") axes.set_ylabel("# Pauli terms") - axes.legend() + _set_legend(axes, show_legend) -def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None: +def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None: """Plot the number of truncated Pauli terms. This method populates the provided figure axes with a bar-plot of the number of the truncated @@ -275,6 +282,7 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ num_observables = len(metadata.backpropagation_history[0].slice_errors) width = 0.8 / num_observables @@ -291,10 +299,10 @@ def plot_num_truncated_paulis(metadata: OBPMetadata, axes: Axes) -> None: offset += width axes.set_xlabel("backpropagated slice number") axes.set_ylabel("# truncated Pauli terms") - axes.legend() + _set_legend(axes, show_legend) -def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None: +def plot_sum_paulis(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None: """Plot the total number of all Pauli terms. This method populates the provided figure axes with a line-plot of the total number of all Pauli @@ -329,6 +337,7 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ if metadata.operator_budget.max_paulis is not None: axes.axhline( @@ -346,10 +355,10 @@ def plot_sum_paulis(metadata: OBPMetadata, axes: Axes) -> None: ) axes.set_xlabel("backpropagated slice number") axes.set_ylabel("total # of Pauli terms") - axes.legend() + _set_legend(axes, show_legend) -def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None: +def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes, show_legend: bool = True) -> None: """Plot the number of qubit-wise commuting Pauli groups. This method populates the provided figure axes with a line-plot of the number of qubit-wise @@ -380,6 +389,7 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None: Args: metadata: the metadata to be visualized. axes: the matplotlib axes in which to plot. + show_legend: enable/disable showing the legend in the plot. """ if metadata.operator_budget.max_qwc_groups is not None: axes.axhline( @@ -397,4 +407,9 @@ def plot_num_qwc_groups(metadata: OBPMetadata, axes: Axes) -> None: ) axes.set_xlabel("backpropagated slice number") axes.set_ylabel("# of qubit-wise commuting Pauli groups") - axes.legend() + _set_legend(axes, show_legend) + + +def _set_legend(axes: Axes, show_legend: bool) -> None: + if show_legend: + axes.legend()