diff --git a/src/tof/result.py b/src/tof/result.py index 4f0a119..5b4d6d0 100644 --- a/src/tof/result.py +++ b/src/tof/result.py @@ -297,22 +297,25 @@ def plot( fig, ax = plt.subplots(figsize=figsize) else: fig = ax.get_figure() - furthest_detector = max(self._detectors.values(), key=lambda d: d.distance) + furthest_component = max( + chain(self._choppers.values(), self._detectors.values()), + key=lambda x: x.distance, + ) wavelengths = sc.DataArray( - data=furthest_detector.data.coords['wavelength'], - masks=furthest_detector.data.masks, + data=furthest_component.data.coords['wavelength'], + masks=furthest_component.data.masks, ) for i in range(self._source.data.sizes['pulse']): self._plot_blocked_rays( blocked_rays=blocked_rays, pulse_index=i, - furthest_detector=furthest_detector, + furthest_detector=furthest_component, ax=ax, ) self._plot_visible_rays( max_rays=max_rays, pulse_index=i, - furthest_detector=furthest_detector, + furthest_detector=furthest_component, ax=ax, cbar=cbar and (i == 0), wmin=wavelengths.min(), @@ -321,9 +324,9 @@ def plot( ) self._plot_pulse(pulse_index=i, ax=ax) - det_data = furthest_detector.toas.visible.data - if sum(da.sum().value for da in det_data.values()) > 0: - times = (da.coords['toa'].max() for da in det_data.values()) + comp_data = furthest_component.toas.visible.data + if sum(da.sum().value for da in comp_data.values()) > 0: + times = (da.coords['toa'].max() for da in comp_data.values()) else: times = (ch.close_times.max() for ch in self._choppers.values()) toa_max = reduce(max, times).value @@ -363,7 +366,7 @@ def plot( inches = fig.get_size_inches() fig.set_size_inches( ( - min(inches[0] * furthest_detector.data.sizes['pulse'], 12.0), + min(inches[0] * furthest_component.data.sizes['pulse'], 12.0), inches[1], ) ) diff --git a/tests/result_test.py b/tests/result_test.py index b8122a3..2cf485c 100644 --- a/tests/result_test.py +++ b/tests/result_test.py @@ -261,6 +261,15 @@ def test_result_plot_cmap_does_not_raise(): res.plot(cmap='viridis') +def test_result_plot_no_detectors_does_not_raise(): + model = make_ess_model() + model._detectors = {} + res = model.run() + res.plot() + res.plot(max_rays=5000) + res.plot(max_rays=50, blocked_rays=3000) + + def test_result_repr_does_not_raise(): model = make_ess_model() res = model.run()