Skip to content

Commit

Permalink
RFC: avoid pyplot interface in lib
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Nov 2, 2024
1 parent c22316c commit 11624b6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
49 changes: 39 additions & 10 deletions nonos/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from importlib.metadata import version
from multiprocessing import Pool
from pathlib import Path
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import cblind # noqa
import inifix
Expand Down Expand Up @@ -42,10 +42,40 @@
)
from nonos.styling import set_mpl_style

if TYPE_CHECKING:
from matplotlib.backend_bases import FigureCanvasBase
from matplotlib.figure import Figure


NONOS_VERSION = version("nonos")
INIFIX_GE_5_0 = Version(version("inifix")) >= Version("5.0.0")


def get_non_interactive_figure(fmt: str) -> "Figure":
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.backends.backend_pdf import FigureCanvasPdf
from matplotlib.backends.backend_ps import FigureCanvasPS
from matplotlib.backends.backend_svg import FigureCanvasSVG
from matplotlib.figure import Figure

FigureCanvas: type[FigureCanvasBase]

if fmt in ["png", "jpg", "jpeg", "raw", "rgba", "tif", "tiff"]:
FigureCanvas = FigureCanvasAgg
elif fmt == "pdf":
FigureCanvas = FigureCanvasPdf
elif fmt in ["ps", "eps"]:
FigureCanvas = FigureCanvasPS
elif fmt == "svg":
FigureCanvas = FigureCanvasSVG
else:
raise ValueError(f"unknown file format {fmt}")

fig = Figure()
FigureCanvas(fig)
return fig


# process function for parallelisation purpose with progress bar
# counterParallel = Value('i', 0) # initialization of a counter
def process_field(
Expand Down Expand Up @@ -75,12 +105,6 @@ def process_field(
*,
log_level,
):
import matplotlib.pyplot as plt

if not show:
# ref https://github.com/matplotlib/matplotlib/issues/28957
plt.switch_backend("Agg")

configure_logger(level=log_level)
set_mpl_style(scaling=scaling)
if geometry == "unset":
Expand Down Expand Up @@ -124,7 +148,13 @@ def process_field(
# default_plane = ["x","y"]
plane = default_plane

fig = plt.figure()
if show:
import matplotlib.pyplot as plt

fig = plt.figure()
else:
fig = get_non_interactive_figure(fmt)

ax = fig.add_subplot(111, polar=False)
if dim == 1:
dsop.map(plane[0], rotate_with=planet_file).plot(
Expand Down Expand Up @@ -180,14 +210,13 @@ def process_field(

if show:
plt.show()
plt.close(fig)
else:
logger.debug("saving plot: started")
filename = f"{''.join(plane)}_{field}_{'_'.join(operations)}{'_diff' if diff else '_'}{'_log' if log else ''}{on:04d}.{fmt}"
fig.savefig(filename, bbox_inches="tight", dpi=dpi)
logger.debug("saving plot: finished ({})", filename)

plt.close(fig)


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
Expand Down
4 changes: 2 additions & 2 deletions nonos/styling.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def scale_mpl(scaling: float) -> None:

def set_mpl_style(scaling: float) -> None:
if mpl.__version_info__ >= (3, 7):
import matplotlib.pyplot as plt
import matplotlib.style

plt.style.use("nonos.default")
matplotlib.style.use("nonos.default")
else:
# promise mypy this is a Path to get around a broad return type
# from importlib_resource.files
Expand Down

0 comments on commit 11624b6

Please sign in to comment.