Skip to content

Commit

Permalink
Merge pull request #65 from scipp/add-cmap-arg
Browse files Browse the repository at this point in the history
Add cmap arg
  • Loading branch information
nvaytet authored Nov 26, 2024
2 parents cb489e1 + e5e57cb commit ec9a76f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/tof/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _add_rays(
birth_times: sc.Variable,
distances: sc.Variable,
cbar: bool = True,
cmap: str = 'gist_rainbow_r',
wavelengths: Optional[sc.Variable] = None,
wmin: Optional[sc.Variable] = None,
wmax: Optional[sc.Variable] = None,
Expand All @@ -62,7 +63,7 @@ def _add_rays(
)
coll = LineCollection(segments)
if wavelengths is not None:
coll.set_cmap(plt.cm.gist_rainbow_r)
coll.set_cmap(plt.colormaps[cmap])
coll.set_array(wavelengths.values)
coll.set_norm(plt.Normalize(wmin.value, wmax.value))
if cbar:
Expand Down Expand Up @@ -174,6 +175,7 @@ def _plot_visible_rays(
cbar: bool,
wmin: sc.Variable,
wmax: sc.Variable,
cmap: str,
):
da = furthest_detector.data['pulse', pulse_index]
visible = da[~da.masks['blocked_by_others']]
Expand All @@ -196,6 +198,7 @@ def _plot_visible_rays(
wavelengths=wavelengths,
wmin=wmin,
wmax=wmax,
cmap=cmap,
)

def _plot_blocked_rays(
Expand Down Expand Up @@ -265,6 +268,7 @@ def plot(
figsize: Optional[Tuple[float, float]] = None,
ax: Optional[plt.Axes] = None,
cbar: bool = True,
cmap: str = 'gist_rainbow_r',
) -> Plot:
"""
Plot the time-distance diagram for the instrument, including the rays of
Expand All @@ -286,6 +290,8 @@ def plot(
Axes to plot on.
cbar:
Show a colorbar for the wavelength if ``True``.
cmap:
Colormap to use for the wavelength colorbar.
"""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -311,6 +317,7 @@ def plot(
cbar=cbar and (i == 0),
wmin=wavelengths.min(),
wmax=wavelengths.max(),
cmap=cmap,
)
self._plot_pulse(pulse_index=i, ax=ax)

Expand Down
6 changes: 6 additions & 0 deletions tests/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ def test_result_all_neutrons_blocked_does_not_raise():
res.plot(blocked_rays=500)


def test_result_plot_cmap_does_not_raise():
model = make_ess_model()
res = model.run()
res.plot(cmap='viridis')


def test_result_repr_does_not_raise():
model = make_ess_model()
res = model.run()
Expand Down

0 comments on commit ec9a76f

Please sign in to comment.