diff --git a/nonos/main.py b/nonos/main.py index 74ceb1ea..f28f7480 100644 --- a/nonos/main.py +++ b/nonos/main.py @@ -14,7 +14,7 @@ from importlib.metadata import version from multiprocessing import Pool from pathlib import Path -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import cblind # noqa import inifix @@ -42,10 +42,40 @@ ) from nonos.styling import set_mpl_style +if TYPE_CHECKING: + from matplotlib.backend_bases import FigureCanvasBase + from matplotlib.figure import Figure + + NONOS_VERSION = version("nonos") INIFIX_GE_5_0 = Version(version("inifix")) >= Version("5.0.0") +def get_non_interactive_figure(fmt: str) -> "Figure": + from matplotlib.backends.backend_agg import FigureCanvasAgg + from matplotlib.backends.backend_pdf import FigureCanvasPdf + from matplotlib.backends.backend_ps import FigureCanvasPS + from matplotlib.backends.backend_svg import FigureCanvasSVG + from matplotlib.figure import Figure + + FigureCanvas: type[FigureCanvasBase] + + if fmt in ["png", "jpg", "jpeg", "raw", "rgba", "tif", "tiff"]: + FigureCanvas = FigureCanvasAgg + elif fmt == "pdf": + FigureCanvas = FigureCanvasPdf + elif fmt in ["ps", "eps"]: + FigureCanvas = FigureCanvasPS + elif fmt == "svg": + FigureCanvas = FigureCanvasSVG + else: + raise ValueError(f"unknown file format {fmt}") + + fig = Figure() + FigureCanvas(fig) + return fig + + # process function for parallelisation purpose with progress bar # counterParallel = Value('i', 0) # initialization of a counter def process_field( @@ -75,12 +105,6 @@ def process_field( *, log_level, ): - import matplotlib.pyplot as plt - - if not show: - # ref https://github.com/matplotlib/matplotlib/issues/28957 - plt.switch_backend("Agg") - configure_logger(level=log_level) set_mpl_style(scaling=scaling) if geometry == "unset": @@ -124,7 +148,13 @@ def process_field( # default_plane = ["x","y"] plane = default_plane - fig = plt.figure() + if show: + import matplotlib.pyplot as plt + + fig = plt.figure() + else: + fig = get_non_interactive_figure(fmt) + ax = fig.add_subplot(111, polar=False) if dim == 1: dsop.map(plane[0], rotate_with=planet_file).plot( @@ -180,14 +210,13 @@ def process_field( if show: plt.show() + plt.close(fig) else: logger.debug("saving plot: started") filename = f"{''.join(plane)}_{field}_{'_'.join(operations)}{'_diff' if diff else '_'}{'_log' if log else ''}{on:04d}.{fmt}" fig.savefig(filename, bbox_inches="tight", dpi=dpi) logger.debug("saving plot: finished ({})", filename) - plt.close(fig) - def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( diff --git a/nonos/styling.py b/nonos/styling.py index 54058b15..27efecb0 100644 --- a/nonos/styling.py +++ b/nonos/styling.py @@ -52,9 +52,9 @@ def scale_mpl(scaling: float) -> None: def set_mpl_style(scaling: float) -> None: if mpl.__version_info__ >= (3, 7): - import matplotlib.pyplot as plt + import matplotlib.style - plt.style.use("nonos.default") + matplotlib.style.use("nonos.default") else: # promise mypy this is a Path to get around a broad return type # from importlib_resource.files