Skip to content

Commit

Permalink
Update correlator tests
Browse files Browse the repository at this point in the history
Lots of minor fixes to the tests, and fixed one bug in the correlator that
caused an extra nan value to be added to the error arrays for each series.
  • Loading branch information
JamesWrigley committed Jul 27, 2023
1 parent d1f09da commit 46ff072
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
5 changes: 2 additions & 3 deletions extra_foam/special_suite/correlator_w.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,15 +663,14 @@ def handle_rich_output():
for i, value in enumerate(data.values[1:]):
label = y_series_labels[i]

self._xs[label].append(data.values[0])
self.update_scalar_series(label, value)

# TODO: move this block of code into update_scalar_series()
if label in series_errors and not self.xbinning_enabled:
if label not in self._errors and len(self._xs[label]) > 0:
self._errors[label].extend(np.full((len(self._xs[label]), ), np.nan))
self._errors[label].append(series_errors[label])

self._xs[label].append(data.values[0])
self.update_scalar_series(label, value)

# If we're dealing with Vector data
elif data.ndim == 2:
Expand Down
20 changes: 10 additions & 10 deletions extra_foam/special_suite/tests/test_correlator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import patch, ANY

import pytest
Expand Down Expand Up @@ -72,20 +73,19 @@ def foo(tid: "internal#train_id"):
"""
ctx = textwrap.dedent(ctx)

with tempfile.NamedTemporaryFile() as ctx_file:
with tempfile.NamedTemporaryFile(suffix=".py") as ctx_file:
# Save the context to a file
ctx_file.write(ctx.encode())
ctx_file.flush()
path = ctx_file.name

# Helper function to read the current contents of the context file
def saved_ctx():
ctx_file.seek(0)
return ctx_file.read().decode()
return Path(path).read_text()

with patch.object(QFileDialog, "getOpenFileName", return_value=(path, )):
# Open the context file
win._openContext()
win._ctrl_widget_st.onOpenFile()

# Check the path and source is displayed correctly
assert win._ctrl_widget_st._path_label.text() == path
Expand Down Expand Up @@ -284,7 +284,7 @@ def testUberSplitter(self, win):

def testViewWidget(self, win, initial_context):
widget = win._tab_widget.widget(1).widget(0)
plot_widget = widget._plot_widget
plot_widget_splitter = widget._plot_widget_splitter
view_picker = widget.view_picker
view_picker_widget = view_picker.parent()
assert type(widget) == ViewWidget
Expand All @@ -298,11 +298,11 @@ def testViewWidget(self, win, initial_context):
# Selecting the image view should show the image view widget, everything
# else should show the plot widget.
view_picker.setCurrentText("view#compute")
assert widget.currentWidget() == plot_widget, "Wrong widget displayed for View.Compute"
assert widget.currentWidget() == plot_widget_splitter, "Wrong widget displayed for View.Compute"
view_picker.setCurrentText("view#scalar")
assert widget.currentWidget() == plot_widget, "Wrong widget displayed for View.Scalar"
assert widget.currentWidget() == plot_widget_splitter, "Wrong widget displayed for View.Scalar"
view_picker.setCurrentText("view#vector")
assert widget.currentWidget() == plot_widget, "Wrong widget displayed for View.Vector"
assert widget.currentWidget() == plot_widget_splitter, "Wrong widget displayed for View.Vector"
view_picker.setCurrentText("view#image")
assert widget.currentWidget() == widget._image_widget, "Wrong widget displayed for View.Image"

Expand Down Expand Up @@ -365,7 +365,7 @@ def test1dPlotting(self, view_type, output_data, win, initial_context, caplog):
setTitle.assert_called_with("Baz")

# There should be 'max_points' points
assert len(widget._xs) == max_points
assert len(widget._xs["y0"]) == max_points
assert len(widget._ys["y0"]) == max_points

widget.reset()
Expand All @@ -383,7 +383,7 @@ def test1dPlotting(self, view_type, output_data, win, initial_context, caplog):
# rather than how many trains in total have been processed.
output_len = len(output_data[0]) if is_vector else len(output_data)

assert len(widget._xs) == output_len
assert len(widget._xs["Foo"]) == output_len
for series in ["Foo", "Bar"]:
assert len(widget._ys[series]) == output_len
assert len(widget._errors[series]) == output_len
Expand Down

0 comments on commit 46ff072

Please sign in to comment.