From 4ce96fc0aadeaeca1896674a065fafc8562eca22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Thu, 21 Nov 2024 18:43:52 +0100 Subject: [PATCH] BUG: fix missing progress bar when running CLI on a single CPU --- nonos/main.py | 34 +++++++++++++++------------------- tests/test_plotting.py | 2 +- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/nonos/main.py b/nonos/main.py index f28f7480..610deecc 100644 --- a/nonos/main.py +++ b/nonos/main.py @@ -531,17 +531,6 @@ def main(argv: Optional[list[str]] = None) -> int: f"Requested {args['ncpu']}, but the runner only has access to {ncpu}." ) - if args["progressBar"]: - from rich.progress import track - - def mytrack(iterable, *args, **kwargs): - return track(iterable, *args, **kwargs) - - else: - # replace rich.progress.track with a no-op dummy - def mytrack(iterable, *args, **kwargs): # noqa: ARG001 - return iterable - planet_file: Optional[str] if not is_set(args["corotate"]): planet_file = None @@ -576,21 +565,28 @@ def mytrack(iterable, *args, **kwargs): # noqa: ARG001 log_level=level, ) + if args["progressBar"]: + from rich.progress import track + else: + # replace rich.progress.track with a no-op dummy + def track(it, *_args, **_kwargs): # type: ignore [misc] + return it + + progress = functools.partial( + track, + description="Processing snapshots", + total=len(args["on"]), + ) + logger.info("Starting main loop") tstart = time.time() if ncpu == 1: - for on in args["on"]: + for on in progress(args["on"]): process_field(on, **func_kwargs) else: func = functools.partial(process_field, **func_kwargs) with Pool(ncpu) as pool: - list( - mytrack( - pool.imap(func, args["on"]), - description="Processing snapshots", - total=len(args["on"]), - ) - ) + list(progress(pool.imap(func, args["on"]))) if not show: logger.info("Operation took {:.2f}s", time.time() - tstart) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 0f76260b..d65eec3d 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -175,7 +175,7 @@ def test_pbar(simulation_dir, capsys, tmp_path): out, err = capsys.readouterr() assert err == "" - assert out == "" + assert "Processing snapshots" in out assert ret == 0