From 804d1c8fe831734e1bb8175f545241b38224ac00 Mon Sep 17 00:00:00 2001 From: Rui Coelho Date: Thu, 6 Jun 2024 14:30:51 +0100 Subject: [PATCH] Fix tests for `piglot-plot` --- piglot/bin/piglot_plot.py | 4 +++- test/test_examples_plots.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/piglot/bin/piglot_plot.py b/piglot/bin/piglot_plot.py index 2289036..02f494c 100644 --- a/piglot/bin/piglot_plot.py +++ b/piglot/bin/piglot_plot.py @@ -295,7 +295,9 @@ def plot_gp(args): x_max = max(par.ubound for par in parameters) x = torch.linspace(x_min, x_max, 1000) for name, data_dict in data.items(): - max_calls = args.max_calls if args.max_calls else len(data_dict['values']) + max_calls = len(data_dict['values']) + if args.max_calls: + max_calls = min(max_calls, args.max_calls) values = data_dict['values'][:max_calls] param_values = data_dict['params'][:max_calls] variances = data_dict['variances'][:max_calls] if 'variances' in data_dict else None diff --git a/test/test_examples_plots.py b/test/test_examples_plots.py index bc6fdb9..0830741 100644 --- a/test/test_examples_plots.py +++ b/test/test_examples_plots.py @@ -90,13 +90,13 @@ def test_input_files(input_dir: str): piglot_plot_main([ 'animation', input_file, - '--save_fig', - os.path.join(output_dir, 'animation.png'), ]) if input_file != 'test_analytical.yaml': piglot_plot_main([ 'gp', input_file, + '--max_calls', + '10', '--save_fig', os.path.join(output_dir, 'gp.png'), ])