From ef83c36f4a3f574048d1b0237fd52c93ada4a729 Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Tue, 26 Nov 2024 17:55:07 +0100 Subject: [PATCH 1/2] add cmap arg to result.plot --- src/tof/result.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/tof/result.py b/src/tof/result.py index 3ee3dde..ea071be 100644 --- a/src/tof/result.py +++ b/src/tof/result.py @@ -45,6 +45,7 @@ def _add_rays( birth_times: sc.Variable, distances: sc.Variable, cbar: bool = True, + cmap: str = 'gist_rainbow_r', wavelengths: Optional[sc.Variable] = None, wmin: Optional[sc.Variable] = None, wmax: Optional[sc.Variable] = None, @@ -62,7 +63,7 @@ def _add_rays( ) coll = LineCollection(segments) if wavelengths is not None: - coll.set_cmap(plt.cm.gist_rainbow_r) + coll.set_cmap(plt.colormaps[cmap]) coll.set_array(wavelengths.values) coll.set_norm(plt.Normalize(wmin.value, wmax.value)) if cbar: @@ -174,6 +175,7 @@ def _plot_visible_rays( cbar: bool, wmin: sc.Variable, wmax: sc.Variable, + cmap: str, ): da = furthest_detector.data['pulse', pulse_index] visible = da[~da.masks['blocked_by_others']] @@ -196,6 +198,7 @@ def _plot_visible_rays( wavelengths=wavelengths, wmin=wmin, wmax=wmax, + cmap=cmap, ) def _plot_blocked_rays( @@ -265,6 +268,7 @@ def plot( figsize: Optional[Tuple[float, float]] = None, ax: Optional[plt.Axes] = None, cbar: bool = True, + cmap: str = 'gist_rainbow_r', ) -> Plot: """ Plot the time-distance diagram for the instrument, including the rays of @@ -286,6 +290,8 @@ def plot( Axes to plot on. cbar: Show a colorbar for the wavelength if ``True``. + cmap: + Colormap to use for the wavelength colorbar. """ if ax is None: fig, ax = plt.subplots(figsize=figsize) @@ -311,6 +317,7 @@ def plot( cbar=cbar and (i == 0), wmin=wavelengths.min(), wmax=wavelengths.max(), + cmap=cmap, ) self._plot_pulse(pulse_index=i, ax=ax) From 16be454ede2446076e0d2f333f9a15beca2becce Mon Sep 17 00:00:00 2001 From: Neil Vaytet Date: Tue, 26 Nov 2024 17:56:19 +0100 Subject: [PATCH 2/2] add cmap test --- tests/result_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/result_test.py b/tests/result_test.py index 626436a..b8122a3 100644 --- a/tests/result_test.py +++ b/tests/result_test.py @@ -255,6 +255,12 @@ def test_result_all_neutrons_blocked_does_not_raise(): res.plot(blocked_rays=500) +def test_result_plot_cmap_does_not_raise(): + model = make_ess_model() + res = model.run() + res.plot(cmap='viridis') + + def test_result_repr_does_not_raise(): model = make_ess_model() res = model.run()