From 46ff0728958587934e9f607fac39c6ebe9f1e5e0 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Thu, 27 Jul 2023 13:35:19 +0200 Subject: [PATCH] Update correlator tests Lots of minor fixes to the tests, and fixed one bug in the correlator that caused an extra nan value to be added to the error arrays for each series. --- extra_foam/special_suite/correlator_w.py | 5 ++--- .../special_suite/tests/test_correlator.py | 20 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/extra_foam/special_suite/correlator_w.py b/extra_foam/special_suite/correlator_w.py index 2d8cc1e2..cf913dd7 100644 --- a/extra_foam/special_suite/correlator_w.py +++ b/extra_foam/special_suite/correlator_w.py @@ -663,15 +663,14 @@ def handle_rich_output(): for i, value in enumerate(data.values[1:]): label = y_series_labels[i] - self._xs[label].append(data.values[0]) - self.update_scalar_series(label, value) - # TODO: move this block of code into update_scalar_series() if label in series_errors and not self.xbinning_enabled: if label not in self._errors and len(self._xs[label]) > 0: self._errors[label].extend(np.full((len(self._xs[label]), ), np.nan)) self._errors[label].append(series_errors[label]) + self._xs[label].append(data.values[0]) + self.update_scalar_series(label, value) # If we're dealing with Vector data elif data.ndim == 2: diff --git a/extra_foam/special_suite/tests/test_correlator.py b/extra_foam/special_suite/tests/test_correlator.py index 647146d8..bf7740e3 100644 --- a/extra_foam/special_suite/tests/test_correlator.py +++ b/extra_foam/special_suite/tests/test_correlator.py @@ -1,6 +1,7 @@ import time import tempfile import textwrap +from pathlib import Path from unittest.mock import patch, ANY import pytest @@ -72,7 +73,7 @@ def foo(tid: "internal#train_id"): """ ctx = textwrap.dedent(ctx) - with tempfile.NamedTemporaryFile() as ctx_file: + with tempfile.NamedTemporaryFile(suffix=".py") as ctx_file: # Save the context to a file ctx_file.write(ctx.encode()) ctx_file.flush() @@ -80,12 +81,11 @@ def foo(tid: "internal#train_id"): # Helper function to read the current contents of the context file def saved_ctx(): - ctx_file.seek(0) - return ctx_file.read().decode() + return Path(path).read_text() with patch.object(QFileDialog, "getOpenFileName", return_value=(path, )): # Open the context file - win._openContext() + win._ctrl_widget_st.onOpenFile() # Check the path and source is displayed correctly assert win._ctrl_widget_st._path_label.text() == path @@ -284,7 +284,7 @@ def testUberSplitter(self, win): def testViewWidget(self, win, initial_context): widget = win._tab_widget.widget(1).widget(0) - plot_widget = widget._plot_widget + plot_widget_splitter = widget._plot_widget_splitter view_picker = widget.view_picker view_picker_widget = view_picker.parent() assert type(widget) == ViewWidget @@ -298,11 +298,11 @@ def testViewWidget(self, win, initial_context): # Selecting the image view should show the image view widget, everything # else should show the plot widget. view_picker.setCurrentText("view#compute") - assert widget.currentWidget() == plot_widget, "Wrong widget displayed for View.Compute" + assert widget.currentWidget() == plot_widget_splitter, "Wrong widget displayed for View.Compute" view_picker.setCurrentText("view#scalar") - assert widget.currentWidget() == plot_widget, "Wrong widget displayed for View.Scalar" + assert widget.currentWidget() == plot_widget_splitter, "Wrong widget displayed for View.Scalar" view_picker.setCurrentText("view#vector") - assert widget.currentWidget() == plot_widget, "Wrong widget displayed for View.Vector" + assert widget.currentWidget() == plot_widget_splitter, "Wrong widget displayed for View.Vector" view_picker.setCurrentText("view#image") assert widget.currentWidget() == widget._image_widget, "Wrong widget displayed for View.Image" @@ -365,7 +365,7 @@ def test1dPlotting(self, view_type, output_data, win, initial_context, caplog): setTitle.assert_called_with("Baz") # There should be 'max_points' points - assert len(widget._xs) == max_points + assert len(widget._xs["y0"]) == max_points assert len(widget._ys["y0"]) == max_points widget.reset() @@ -383,7 +383,7 @@ def test1dPlotting(self, view_type, output_data, win, initial_context, caplog): # rather than how many trains in total have been processed. output_len = len(output_data[0]) if is_vector else len(output_data) - assert len(widget._xs) == output_len + assert len(widget._xs["Foo"]) == output_len for series in ["Foo", "Bar"]: assert len(widget._ys[series]) == output_len assert len(widget._errors[series]) == output_len