diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e050b8c9..490abd83 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -3,36 +3,45 @@ name: Coverage on: pull_request: paths-ignore: - - 'doc/**' - - '.ci/**' - - '*.rst' + - "doc/**" + - ".ci/**" + - "*.rst" push: branches: - main - develop - beta/* paths-ignore: - - 'doc/**' - - '.ci/**' - - '*.rst' + - "doc/**" + - ".ci/**" + - "*.rst" jobs: coverage: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - with: - submodules: true - - uses: hendrikmuhs/ccache-action@v1.2 - with: - key: ${{ github.job }}-${{ matrix.os }}-${{ matrix.python-version }} - create-symlink: true - - uses: rui314/setup-mold@v1 - - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - uses: astral-sh/setup-uv@v4 - - run: uv pip install --system nox - - run: nox -s cov - - uses: AndreMiras/coveralls-python-action@develop + # install-qt-action also does setup-python + - name: Install Qt + uses: jurplel/install-qt-action@v3 + with: + aqtversion: "==3.1.*" + version: "6.8.1" + host: "linux" + target: "desktop" + arch: "linux_gcc_64" + # - uses: actions/setup-python@v5 + # with: + # python-version: "3.12" + - uses: actions/checkout@v4 + with: + submodules: true + - uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ github.job }}-${{ matrix.os }}-${{ matrix.python-version }} + create-symlink: true + - uses: rui314/setup-mold@v1 + - uses: astral-sh/setup-uv@v4 + - run: uv pip install --system nox + - run: nox -s cov + - uses: AndreMiras/coveralls-python-action@develop diff --git a/noxfile.py b/noxfile.py index 3710a3f9..f6f99b27 100644 --- a/noxfile.py +++ b/noxfile.py @@ -67,7 +67,7 @@ def pypy(session: nox.Session) -> None: # Python-3.12 provides coverage info faster -@nox.session(python="3.12", venv_backend="uv", reuse_venv=True) +@nox.session(venv_backend="uv", reuse_venv=True) def cov(session: nox.Session) -> None: """Run covage and place in 'htmlcov' directory.""" session.install("--only-binary=:all:", "-e.[test,doc]") diff --git a/pyproject.toml b/pyproject.toml index 7b8aa6d9..6330a63f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ test = [ "ipywidgets", # needed by ipywidgets >= 8.0.6 "ipykernel", + "PyQt6", "joblib", "jacobi", "matplotlib", @@ -52,6 +53,7 @@ test = [ "numba-stats; platform_python_implementation=='CPython'", "pytest", "pytest-xdist", + "pytest-qt", "scipy", "tabulate", "boost_histogram", @@ -101,6 +103,7 @@ pydocstyle.convention = "numpy" [tool.ruff.lint.per-file-ignores] "test_*.py" = ["B", "D"] +"conftest.py" = ["B", "D"] "*.ipynb" = ["D"] "automatic_differentiation.ipynb" = ["F821"] "cython.ipynb" = ["F821"] diff --git a/src/iminuit/ipywidget.py b/src/iminuit/ipywidget.py index f56bdeef..4e531b3e 100644 --- a/src/iminuit/ipywidget.py +++ b/src/iminuit/ipywidget.py @@ -1,9 +1,9 @@ """Interactive fitting widget for Jupyter notebooks.""" +from .util import _widget_guess_initial_step, _make_finite import warnings import numpy as np from typing import Dict, Any, Callable -import sys with warnings.catch_warnings(): # ipywidgets produces deprecation warnings through use of internal APIs :( @@ -148,7 +148,7 @@ class Parameter(widgets.HBox): def __init__(self, minuit, par): val = minuit.values[par] vmin, vmax = minuit.limits[par] - step = _guess_initial_step(val, vmin, vmax) + step = _widget_guess_initial_step(val, vmin, vmax) vmin2 = vmin if np.isfinite(vmin) else val - 100 * step vmax2 = vmax if np.isfinite(vmax) else val + 100 * step @@ -277,18 +277,5 @@ def reset(self, value, limits=None): return widgets.HBox([out, ui]) -def _make_finite(x: float) -> float: - sign = -1 if x < 0 else 1 - if abs(x) == np.inf: - return sign * sys.float_info.max - return x - - -def _guess_initial_step(val: float, vmin: float, vmax: float) -> float: - if np.isfinite(vmin) and np.isfinite(vmax): - return 1e-2 * (vmax - vmin) - return 1e-2 - - def _round(x: float) -> float: return float(f"{x:.1g}") diff --git a/src/iminuit/minuit.py b/src/iminuit/minuit.py index 399d5525..ee63b064 100644 --- a/src/iminuit/minuit.py +++ b/src/iminuit/minuit.py @@ -2341,10 +2341,14 @@ def interactive( **kwargs, ): """ - Return fitting widget (requires ipywidgets, IPython, matplotlib). + Interactive GUI for fitting. - A fitting widget is returned which can be displayed and manipulated in a - Jupyter notebook to find good starting parameters and to debug the fit. + Starts a fitting application (requires PyQt6, matplotlib) in which the + fit is visualized and the parameters can be manipulated to find good + starting parameters and to debug the fit. + + When called in a Jupyter notebook (requires ipywidgets, IPython, matplotlib), + a fitting widget is returned instead, which can be displayed. Parameters ---------- @@ -2371,9 +2375,14 @@ def interactive( -------- Minuit.visualize """ - from iminuit.ipywidget import make_widget - plot = self._visualize(plot) + + if mutil.is_jupyter(): + from iminuit.ipywidget import make_widget + + else: + from iminuit.qtwidget import make_widget + return make_widget(self, plot, kwargs, raise_on_exception) def _free_parameters(self) -> Set[str]: diff --git a/src/iminuit/qtwidget.py b/src/iminuit/qtwidget.py new file mode 100644 index 00000000..3d0f3d86 --- /dev/null +++ b/src/iminuit/qtwidget.py @@ -0,0 +1,381 @@ +"""Interactive fitting widget using PyQt6.""" + +from .util import _widget_guess_initial_step, _make_finite +import warnings +import numpy as np +from typing import Dict, Any, Callable +from contextlib import contextmanager + +try: + from PyQt6 import QtCore, QtGui, QtWidgets + from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg + from matplotlib import pyplot as plt +except ModuleNotFoundError as e: + e.msg += ( + "\n\nPlease install PyQt6, and matplotlib to enable interactive " + "outside of Jupyter notebooks." + ) + raise + + +def make_widget( + minuit: Any, + plot: Callable[..., None], + kwargs: Dict[str, Any], + raise_on_exception: bool, + run_event_loop: bool = True, +): + """Make interactive fitting widget.""" + original_values = minuit.values[:] + original_limits = minuit.limits[:] + + class Parameter(QtWidgets.QGroupBox): + def __init__(self, minuit, par, callback): + super().__init__("") + self.par = par + self.callback = callback + + size_policy = QtWidgets.QSizePolicy( + QtWidgets.QSizePolicy.Policy.MinimumExpanding, + QtWidgets.QSizePolicy.Policy.Fixed, + ) + self.setSizePolicy(size_policy) + layout = QtWidgets.QVBoxLayout() + self.setLayout(layout) + + label = QtWidgets.QLabel(par, alignment=QtCore.Qt.AlignmentFlag.AlignCenter) + label.setMinimumSize(QtCore.QSize(50, 0)) + self.value_label = QtWidgets.QLabel( + alignment=QtCore.Qt.AlignmentFlag.AlignCenter + ) + self.value_label.setMinimumSize(QtCore.QSize(50, 0)) + self.slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.slider.setMinimum(0) + self.slider.setMaximum(int(1e8)) + self.tmin = QtWidgets.QDoubleSpinBox( + alignment=QtCore.Qt.AlignmentFlag.AlignCenter + ) + self.tmin.setRange(_make_finite(-np.inf), _make_finite(np.inf)) + self.tmax = QtWidgets.QDoubleSpinBox( + alignment=QtCore.Qt.AlignmentFlag.AlignCenter + ) + self.tmax.setRange(_make_finite(-np.inf), _make_finite(np.inf)) + self.tmin.setSizePolicy(size_policy) + self.tmax.setSizePolicy(size_policy) + self.fix = QtWidgets.QPushButton("Fix") + self.fix.setCheckable(True) + self.fix.setChecked(minuit.fixed[par]) + self.fit = QtWidgets.QPushButton("Fit") + self.fit.setCheckable(True) + self.fit.setChecked(False) + size_policy = QtWidgets.QSizePolicy( + QtWidgets.QSizePolicy.Policy.Fixed, QtWidgets.QSizePolicy.Policy.Fixed + ) + self.fix.setSizePolicy(size_policy) + self.fit.setSizePolicy(size_policy) + layout1 = QtWidgets.QHBoxLayout() + layout.addLayout(layout1) + layout1.addWidget(label) + layout1.addWidget(self.slider) + layout1.addWidget(self.value_label) + layout1.addWidget(self.fix) + layout2 = QtWidgets.QHBoxLayout() + layout.addLayout(layout2) + layout2.addWidget(self.tmin) + layout2.addWidget(self.tmax) + layout2.addWidget(self.fit) + + self.reset(minuit.values[par], limits=minuit.limits[par]) + + step_size = 1e-1 * (self.vmax - self.vmin) + decimals = max(int(-np.log10(step_size)) + 2, 0) + self.tmin.setSingleStep(step_size) + self.tmin.setDecimals(decimals) + self.tmax.setSingleStep(step_size) + self.tmax.setDecimals(decimals) + self.tmin.setMinimum(_make_finite(minuit.limits[par][0])) + self.tmax.setMaximum(_make_finite(minuit.limits[par][1])) + + self.slider.valueChanged.connect(self.on_val_changed) + self.fix.clicked.connect(self.on_fix_toggled) + self.tmin.valueChanged.connect(self.on_min_changed) + self.tmax.valueChanged.connect(self.on_max_changed) + self.fit.clicked.connect(self.on_fit_toggled) + + def _int_to_float(self, value): + return self.vmin + (value / 1e8) * (self.vmax - self.vmin) + + def _float_to_int(self, value): + return int((value - self.vmin) / (self.vmax - self.vmin) * 1e8) + + def on_val_changed(self, val): + val = self._int_to_float(val) + self.value_label.setText(f"{val:.3g}") + minuit.values[self.par] = val + self.callback() + + def on_min_changed(self): + tmin = self.tmin.value() + if tmin >= self.vmax: + with _block_signals(self.tmin): + self.tmin.setValue(self.vmin) + return + self.vmin = tmin + with _block_signals(self.slider): + if tmin > self.val: + self.val = tmin + minuit.values[self.par] = tmin + self.slider.setValue(0) + self.value_label.setText(f"{self.val:.3g}") + self.callback() + else: + self.slider.setValue(self._float_to_int(self.val)) + lim = minuit.limits[self.par] + minuit.limits[self.par] = (tmin, lim[1]) + + def on_max_changed(self): + tmax = self.tmax.value() + if tmax <= self.tmin.value(): + with _block_signals(self.tmax): + self.tmax.setValue(self.vmax) + return + self.vmax = tmax + with _block_signals(self.slider): + if tmax < self.val: + self.val = tmax + minuit.values[self.par] = tmax + self.slider.setValue(int(1e8)) + self.value_label.setText(f"{self.val:.3g}") + self.callback() + else: + self.slider.setValue(self._float_to_int(self.val)) + lim = minuit.limits[self.par] + minuit.limits[self.par] = (lim[0], tmax) + + def on_fix_toggled(self): + minuit.fixed[self.par] = self.fix.isChecked() + if self.fix.isChecked(): + self.fit.setChecked(False) + + def on_fit_toggled(self): + self.slider.setEnabled(not self.fit.isChecked()) + if self.fit.isChecked(): + self.fix.setChecked(False) + self.callback() + + def reset(self, val, limits=None): + if limits is not None: + vmin, vmax = limits + step = _widget_guess_initial_step(val, vmin, vmax) + self.vmin = vmin if np.isfinite(vmin) else val - 100 * step + self.vmax = vmax if np.isfinite(vmax) else val + 100 * step + with _block_signals(self.tmin, self.tmax): + self.tmin.setValue(self.vmin) + self.tmax.setValue(self.vmax) + + self.val = val + if self.val < self.vmin: + self.vmin = self.val + with _block_signals(self.tmin): + self.tmin.setValue(self.vmin) + elif self.val > self.vmax: + self.vmax = self.val + with _block_signals(self.tmax): + self.tmax.setValue(self.vmax) + + with _block_signals(self.slider): + self.slider.setValue(self._float_to_int(self.val)) + self.value_label.setText(f"{self.val:.3g}") + + class Widget(QtWidgets.QWidget): + def __init__(self): + super().__init__() + self.resize(1280, 720) + font = QtGui.QFont() + font.setPointSize(12) + self.setFont(font) + self.setWindowTitle("iminuit") + + interactive_layout = QtWidgets.QGridLayout(self) + + plot_group = QtWidgets.QGroupBox("", parent=self) + size_policy = QtWidgets.QSizePolicy( + QtWidgets.QSizePolicy.Policy.MinimumExpanding, + QtWidgets.QSizePolicy.Policy.MinimumExpanding, + ) + plot_group.setSizePolicy(size_policy) + plot_layout = QtWidgets.QVBoxLayout(plot_group) + fig = plt.figure() + manager = plt.get_current_fig_manager() + self.canvas = FigureCanvasQTAgg(fig) + self.canvas.manager = manager + plot_layout.addWidget(self.canvas) + interactive_layout.addWidget(plot_group, 0, 0, 2, 1) + + button_group = QtWidgets.QGroupBox("", parent=self) + size_policy = QtWidgets.QSizePolicy( + QtWidgets.QSizePolicy.Policy.Expanding, + QtWidgets.QSizePolicy.Policy.Fixed, + ) + button_group.setSizePolicy(size_policy) + button_group.setMaximumWidth(500) + button_layout = QtWidgets.QHBoxLayout(button_group) + self.fit_button = QtWidgets.QPushButton("Fit", parent=button_group) + self.fit_button.setStyleSheet("background-color: #2196F3; color: white") + self.fit_button.clicked.connect(lambda: self.do_fit(plot=True)) + button_layout.addWidget(self.fit_button) + self.update_button = QtWidgets.QPushButton( + "Continuous", parent=button_group + ) + self.update_button.setCheckable(True) + self.update_button.setChecked(True) + self.update_button.clicked.connect(self.on_update_button_clicked) + button_layout.addWidget(self.update_button) + self.reset_button = QtWidgets.QPushButton("Reset", parent=button_group) + self.reset_button.setStyleSheet("background-color: #F44336; color: white") + self.reset_button.clicked.connect(self.on_reset_button_clicked) + button_layout.addWidget(self.reset_button) + self.algo_choice = QtWidgets.QComboBox(parent=button_group) + self.algo_choice.setEditable(True) + self.algo_choice.lineEdit().setAlignment( + QtCore.Qt.AlignmentFlag.AlignCenter + ) + self.algo_choice.lineEdit().setReadOnly(True) + self.algo_choice.addItems(["Migrad", "Scipy", "Simplex"]) + button_layout.addWidget(self.algo_choice) + interactive_layout.addWidget(button_group, 0, 1, 1, 1) + + par_scroll_area = QtWidgets.QScrollArea() + par_scroll_area.setWidgetResizable(True) + size_policy = QtWidgets.QSizePolicy( + QtWidgets.QSizePolicy.Policy.MinimumExpanding, + QtWidgets.QSizePolicy.Policy.MinimumExpanding, + ) + par_scroll_area.setSizePolicy(size_policy) + par_scroll_area.setMaximumWidth(500) + scroll_area_contents = QtWidgets.QWidget() + parameter_layout = QtWidgets.QVBoxLayout(scroll_area_contents) + par_scroll_area.setWidget(scroll_area_contents) + interactive_layout.addWidget(par_scroll_area, 1, 1, 2, 1) + self.parameters = [] + for par in minuit.parameters: + parameter = Parameter(minuit, par, self.on_parameter_change) + self.parameters.append(parameter) + parameter_layout.addWidget(parameter) + parameter_layout.addStretch() + + self.results_text = QtWidgets.QTextEdit(parent=self) + self.results_text.setReadOnly(True) + self.results_text.setSizePolicy(size_policy) + self.results_text.setMaximumHeight(144) + interactive_layout.addWidget(self.results_text, 2, 0, 1, 1) + + self.plot_with_frame(from_fit=False, report_success=False) + + def plot_with_frame(self, from_fit, report_success): + trans = plt.gca().transAxes + try: + with warnings.catch_warnings(): + fig_size = plt.gcf().get_size_inches() + minuit.visualize(plot, **kwargs) + plt.gcf().set_size_inches(fig_size) + except Exception: + if raise_on_exception: + raise + + import traceback + + plt.figtext( + 0, + 0.5, + traceback.format_exc(limit=-1), + fontdict={"family": "monospace", "size": "x-small"}, + va="center", + color="r", + backgroundcolor="w", + wrap=True, + ) + return + + fval = minuit.fmin.fval if from_fit else minuit._fcn(minuit.values) + plt.text( + 0.05, + 1.05, + f"FCN = {fval:.3f}", + transform=trans, + fontsize="x-large", + ) + if from_fit and report_success: + self.results_text.clear() + self.results_text.setHtml( + f"
{minuit.fmin._repr_html_()}
" + ) + else: + self.results_text.clear() + + def fit(self): + if self.algo_choice.currentText() == "Migrad": + minuit.migrad() + elif self.algo_choice.currentText() == "Scipy": + minuit.scipy() + elif self.algo_choice.currentText() == "Simplex": + minuit.simplex() + return False + else: + assert False # pragma: no cover, should never happen + return True + + def on_parameter_change(self, from_fit=False, report_success=False): + if any(x.fit.isChecked() for x in self.parameters): + saved = minuit.fixed[:] + for i, x in enumerate(self.parameters): + minuit.fixed[i] = not x.fit.isChecked() + from_fit = True + report_success = self.do_fit(plot=False) + minuit.fixed = saved + + plt.clf() + self.plot_with_frame(from_fit, report_success) + self.canvas.draw_idle() + + def do_fit(self, plot=True): + report_success = self.fit() + for i, x in enumerate(self.parameters): + x.reset(val=minuit.values[i]) + if not plot: + return report_success + self.on_parameter_change(from_fit=True, report_success=report_success) + + def on_update_button_clicked(self): + for x in self.parameters: + x.slider.setTracking(self.update_button.isChecked()) + + def on_reset_button_clicked(self): + minuit.reset() + minuit.values = original_values + minuit.limits = original_limits + for i, x in enumerate(self.parameters): + x.reset(val=minuit.values[i], limits=original_limits[i]) + self.on_parameter_change() + + if run_event_loop: + app = QtWidgets.QApplication.instance() + if app is None: + app = QtWidgets.QApplication([]) + + widget = Widget() + widget.show() + app.exec() # this blocks the main thread + else: + return Widget() + + +@contextmanager +def _block_signals(*widgets): + for w in widgets: + w.blockSignals(True) + try: + yield + finally: + for w in widgets: + w.blockSignals(False) diff --git a/src/iminuit/util.py b/src/iminuit/util.py index 3db8bb51..ff18934d 100644 --- a/src/iminuit/util.py +++ b/src/iminuit/util.py @@ -1684,3 +1684,29 @@ def is_positive_definite(m: ArrayLike) -> bool: return False return True return False + + +def is_jupyter() -> bool: + try: + from IPython import get_ipython + + ip = get_ipython() + return ip.has_trait("kernel") + except ImportError: + return False + except AttributeError: + return False + return False + + +def _make_finite(x: float) -> float: + sign = -1 if x < 0 else 1 + if abs(x) == np.inf: + return sign * sys.float_info.max + return x + + +def _widget_guess_initial_step(val: float, vmin: float, vmax: float) -> float: + if np.isfinite(vmin) and np.isfinite(vmax): + return 1e-2 * (vmax - vmin) + return 1e-2 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..bbce2d19 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest +from unittest.mock import patch, MagicMock + + +@pytest.fixture +def mock_ipython(): + with patch("IPython.get_ipython") as mock_get_ipython: + mock_shell = MagicMock() + + def has_trait(name): + return True + + mock_shell.has_trait.side_effect = has_trait + mock_get_ipython.return_value = mock_shell + yield diff --git a/tests/test_draw.py b/tests/test_draw.py index b4e9e088..9cd3ad5f 100644 --- a/tests/test_draw.py +++ b/tests/test_draw.py @@ -2,8 +2,6 @@ from iminuit import Minuit from pathlib import Path import numpy as np -from numpy.testing import assert_allclose -import contextlib mpl = pytest.importorskip("matplotlib") plt = pytest.importorskip("matplotlib.pyplot") @@ -133,126 +131,3 @@ def test_mnmatrix_7(fig): m = Minuit(lambda x: abs(x) ** 2 + x**4 + 10 * x, x=0) m.migrad() m.draw_mnmatrix(cl=[1, 3]) - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_interactive(): - ipywidgets = pytest.importorskip("ipywidgets") - - def cost(a, b): - return a**2 + b**2 - - class Plot: - def __init__(self): - self.called = False - self.raises = False - - def __call__(self, args): - self.called = True - if self.raises: - raise ValueError("foo") - - @contextlib.contextmanager - def assert_call(self): - self.called = False - yield - assert self.called - - plot = Plot() - - m = Minuit(cost, 1, 1) - with pytest.raises(AttributeError, match="no visualize method"): - m.interactive(raise_on_exception=True) - - with plot.assert_call(): - out1 = m.interactive(plot) - assert isinstance(out1, ipywidgets.HBox) - - # manipulate state to also check this code - ui = out1.children[1] - header, parameters = ui.children - fit_button, update_button, reset_button, algo_select = header.children - with plot.assert_call(): - fit_button.click() - assert_allclose(m.values, (0, 0), atol=1e-5) - with plot.assert_call(): - reset_button.click() - assert_allclose(m.values, (1, 1), atol=1e-5) - - algo_select.value = "Scipy" - with plot.assert_call(): - fit_button.click() - - algo_select.value = "Simplex" - with plot.assert_call(): - fit_button.click() - - update_button.value = False - with plot.assert_call(): - # because of implementation details, we have to trigger the slider several times - for i in range(5): - parameters.children[0].slider.value = i # change first slider - parameters.children[0].fix.value = True - with plot.assert_call(): - parameters.children[0].fit.value = True - - class Cost: - def visualize(self, args): - return plot(args) - - def __call__(self, a, b): - return (a - 100) ** 2 + (b + 100) ** 2 - - c = Cost() - m = Minuit(c, 0, 0) - with plot.assert_call(): - out = m.interactive(raise_on_exception=True) - - # this should modify slider range - ui = out.children[1] - header, parameters = ui.children - fit_button, update_button, reset_button, algo_select = header.children - assert parameters.children[0].slider.max == 1 - assert parameters.children[1].slider.min == -1 - with plot.assert_call(): - fit_button.click() - assert_allclose(m.values, (100, -100), atol=1e-5) - # this should trigger an exception - plot.raises = True - with plot.assert_call(): - fit_button.click() - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_interactive_raises(): - pytest.importorskip("ipywidgets") - - def raiser(args): - raise ValueError - - m = Minuit(lambda x, y: 0, 0, 1) - - # by default do not raise - m.interactive(raiser) - - with pytest.raises(ValueError): - m.interactive(raiser, raise_on_exception=True) - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_interactive_with_array_func(): - pytest.importorskip("ipywidgets") - - def cost(par): - return par[0] ** 2 + (par[1] / 2) ** 2 - - class TraceArgs: - nargs = 0 - - def __call__(self, par): - self.nargs = len(par) - - trace_args = TraceArgs() - m = Minuit(cost, (1, 2)) - m.interactive(trace_args) - assert trace_args.nargs > 0 diff --git a/tests/test_ipywidget.py b/tests/test_ipywidget.py new file mode 100644 index 00000000..5b511af5 --- /dev/null +++ b/tests/test_ipywidget.py @@ -0,0 +1,129 @@ +import pytest +from iminuit import Minuit +from numpy.testing import assert_allclose +import contextlib + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") +ipywidgets = pytest.importorskip("ipywidgets") + +mpl.use("Agg") + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_interactive_ipywidgets(mock_ipython): + def cost(a, b): + return a**2 + b**2 + + class Plot: + def __init__(self): + self.called = False + self.raises = False + + def __call__(self, args): + self.called = True + if self.raises: + raise ValueError("foo") + + @contextlib.contextmanager + def assert_call(self): + self.called = False + yield + assert self.called + + plot = Plot() + + m = Minuit(cost, 1, 1) + + with pytest.raises(AttributeError, match="no visualize method"): + m.interactive(raise_on_exception=True) + + with plot.assert_call(): + out1 = m.interactive(plot) + assert isinstance(out1, ipywidgets.HBox) + + # manipulate state to also check this code + ui = out1.children[1] + header, parameters = ui.children + fit_button, update_button, reset_button, algo_select = header.children + with plot.assert_call(): + fit_button.click() + assert_allclose(m.values, (0, 0), atol=1e-5) + with plot.assert_call(): + reset_button.click() + assert_allclose(m.values, (1, 1), atol=1e-5) + + algo_select.value = "Scipy" + with plot.assert_call(): + fit_button.click() + + algo_select.value = "Simplex" + with plot.assert_call(): + fit_button.click() + + update_button.value = False + with plot.assert_call(): + # because of implementation details, we have to trigger the slider several times + for i in range(5): + parameters.children[0].slider.value = i # change first slider + parameters.children[0].fix.value = True + with plot.assert_call(): + parameters.children[0].fit.value = True + + class Cost: + def visualize(self, args): + return plot(args) + + def __call__(self, a, b): + return (a - 100) ** 2 + (b + 100) ** 2 + + c = Cost() + m = Minuit(c, 0, 0) + with plot.assert_call(): + out = m.interactive(raise_on_exception=True) + + # this should modify slider range + ui = out.children[1] + header, parameters = ui.children + fit_button, update_button, reset_button, algo_select = header.children + assert parameters.children[0].slider.max == 1 + assert parameters.children[1].slider.min == -1 + with plot.assert_call(): + fit_button.click() + assert_allclose(m.values, (100, -100), atol=1e-5) + # this should trigger an exception + plot.raises = True + with plot.assert_call(): + fit_button.click() + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_interactive_ipywidgets_raises(mock_ipython): + def raiser(args): + raise ValueError + + m = Minuit(lambda x, y: 0, 0, 1) + + # by default do not raise + m.interactive(raiser) + + with pytest.raises(ValueError): + m.interactive(raiser, raise_on_exception=True) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_interactive_ipywidgets_with_array_func(mock_ipython): + def cost(par): + return par[0] ** 2 + (par[1] / 2) ** 2 + + class TraceArgs: + nargs = 0 + + def __call__(self, par): + self.nargs = len(par) + + trace_args = TraceArgs() + m = Minuit(cost, (1, 2)) + + m.interactive(trace_args) + assert trace_args.nargs > 0 diff --git a/tests/test_qtwidget.py b/tests/test_qtwidget.py new file mode 100644 index 00000000..9cb28589 --- /dev/null +++ b/tests/test_qtwidget.py @@ -0,0 +1,126 @@ +import pytest +from iminuit import Minuit +from numpy.testing import assert_allclose +import contextlib + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") +PyQt6 = pytest.importorskip("PyQt6") + +mpl.use("Agg") + + +def qtinteractive(m, plot=None, raise_on_exception=False, **kwargs): + from iminuit.qtwidget import make_widget + + return make_widget(m, plot, kwargs, raise_on_exception, run_event_loop=False) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_interactive_pyqt6(qtbot): + def cost(a, b): + return a**2 + b**2 + + class Plot: + def __init__(self): + self.called = False + self.raises = False + + def __call__(self, args): + self.called = True + if self.raises: + raise ValueError("foo") + + @contextlib.contextmanager + def assert_call(self): + self.called = False + yield + assert self.called + + plot = Plot() + + m = Minuit(cost, 1, 1) + + with plot.assert_call(): + mw1 = qtinteractive(m, plot) + qtbot.addWidget(mw1) + assert isinstance(mw1, PyQt6.QtWidgets.QWidget) + + # manipulate state to also check this code + with plot.assert_call(): + mw1.fit_button.click() + assert_allclose(m.values, (0, 0), atol=1e-5) + with plot.assert_call(): + mw1.reset_button.click() + assert_allclose(m.values, (1, 1), atol=1e-5) + + mw1.algo_choice.setCurrentText("Scipy") + with plot.assert_call(): + mw1.fit_button.click() + + mw1.algo_choice.setCurrentText("Simplex") + with plot.assert_call(): + mw1.fit_button.click() + + mw1.update_button.click() + with plot.assert_call(): + mw1.parameters[0].slider.valueChanged.emit(int(5e7)) + mw1.parameters[0].fix.click() + with plot.assert_call(): + mw1.parameters[0].fit.click() + + class Cost: + def visualize(self, args): + return plot(args) + + def __call__(self, a, b): + return (a - 100) ** 2 + (b + 100) ** 2 + + c = Cost() + m = Minuit(c, 0, 0) + with plot.assert_call(): + mw = qtinteractive(m, raise_on_exception=True) + qtbot.addWidget(mw) + + # this should modify slider range + assert mw.parameters[0].vmax == 1 + assert mw.parameters[1].vmin == -1 + with plot.assert_call(): + mw.fit_button.click() + assert_allclose(m.values, (100, -100), atol=1e-5) + # this should trigger an exception + # plot.raises = True + # with plot.assert_call(): + # mw.fit_button.click() + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_interactive_pyqt6_raises(qtbot): + def raiser(args): + raise ValueError + + m = Minuit(lambda x, y: 0, 0, 1) + + # by default do not raise + qtinteractive(m, raiser) + + with pytest.raises(ValueError): + qtinteractive(m, raiser, raise_on_exception=True) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_interactive_pyqt6_with_array_func(qtbot): + def cost(par): + return par[0] ** 2 + (par[1] / 2) ** 2 + + class TraceArgs: + nargs = 0 + + def __call__(self, par): + self.nargs = len(par) + + trace_args = TraceArgs() + m = Minuit(cost, (1, 2)) + + qtinteractive(m, trace_args) + assert trace_args.nargs > 0 diff --git a/tests/test_without_ipywidgets.py b/tests/test_without_ipywidgets.py index fbd9b508..48445901 100644 --- a/tests/test_without_ipywidgets.py +++ b/tests/test_without_ipywidgets.py @@ -5,7 +5,7 @@ pytest.importorskip("ipywidgets") -def test_interactive(): +def test_interactive(mock_ipython): pytest.importorskip("matplotlib") import iminuit @@ -14,5 +14,5 @@ def test_interactive(): iminuit.Minuit(cost, 1).interactive() with hide_modules("ipywidgets", reload="iminuit.ipywidget"): - with pytest.raises(ModuleNotFoundError, match="Please install"): + with pytest.raises(ModuleNotFoundError, match="Please install ipywidgets"): iminuit.Minuit(cost, 1).interactive() diff --git a/tests/test_without_pyqt6.py b/tests/test_without_pyqt6.py new file mode 100644 index 00000000..84e49c2b --- /dev/null +++ b/tests/test_without_pyqt6.py @@ -0,0 +1,14 @@ +from iminuit._hide_modules import hide_modules +from iminuit.cost import LeastSquares +import pytest + + +def test_interactive(): + pytest.importorskip("matplotlib") + import iminuit + + cost = LeastSquares([1.1, 2.2], [3.3, 4.4], 1, lambda x, a: a * x) + + with hide_modules("PyQt6", reload="iminuit.qtwidget"): + with pytest.raises(ModuleNotFoundError, match="Please install PyQt6"): + iminuit.Minuit(cost, 1).interactive()