Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cmap arg #65

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/tof/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _add_rays(
birth_times: sc.Variable,
distances: sc.Variable,
cbar: bool = True,
cmap: str = 'gist_rainbow_r',
wavelengths: Optional[sc.Variable] = None,
wmin: Optional[sc.Variable] = None,
wmax: Optional[sc.Variable] = None,
Expand All @@ -62,7 +63,7 @@ def _add_rays(
)
coll = LineCollection(segments)
if wavelengths is not None:
coll.set_cmap(plt.cm.gist_rainbow_r)
coll.set_cmap(plt.colormaps[cmap])
coll.set_array(wavelengths.values)
coll.set_norm(plt.Normalize(wmin.value, wmax.value))
if cbar:
Expand Down Expand Up @@ -174,6 +175,7 @@ def _plot_visible_rays(
cbar: bool,
wmin: sc.Variable,
wmax: sc.Variable,
cmap: str,
):
da = furthest_detector.data['pulse', pulse_index]
visible = da[~da.masks['blocked_by_others']]
Expand All @@ -196,6 +198,7 @@ def _plot_visible_rays(
wavelengths=wavelengths,
wmin=wmin,
wmax=wmax,
cmap=cmap,
)

def _plot_blocked_rays(
Expand Down Expand Up @@ -265,6 +268,7 @@ def plot(
figsize: Optional[Tuple[float, float]] = None,
ax: Optional[plt.Axes] = None,
cbar: bool = True,
cmap: str = 'gist_rainbow_r',
) -> Plot:
"""
Plot the time-distance diagram for the instrument, including the rays of
Expand All @@ -286,6 +290,8 @@ def plot(
Axes to plot on.
cbar:
Show a colorbar for the wavelength if ``True``.
cmap:
Colormap to use for the wavelength colorbar.
"""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -311,6 +317,7 @@ def plot(
cbar=cbar and (i == 0),
wmin=wavelengths.min(),
wmax=wavelengths.max(),
cmap=cmap,
)
self._plot_pulse(pulse_index=i, ax=ax)

Expand Down
6 changes: 6 additions & 0 deletions tests/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ def test_result_all_neutrons_blocked_does_not_raise():
res.plot(blocked_rays=500)


def test_result_plot_cmap_does_not_raise():
model = make_ess_model()
res = model.run()
res.plot(cmap='viridis')


def test_result_repr_does_not_raise():
model = make_ess_model()
res = model.run()
Expand Down