Skip to content

Commit

Permalink
Merge branch 'tickets/PIPE2D-1477'
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertLuptonTheGood committed Jun 19, 2024
2 parents a2cd593 + c84397b commit 5637f14
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 57 deletions.
63 changes: 50 additions & 13 deletions python/pfs/drp/stella/datamodel/pfsFiberArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def __itruediv__(self, rhs: ArrayLike) -> "PfsSimpleSpectrum":
self.flux /= rhs
return self

def plot(self, ignorePixelMask: Optional[int] = None, show: bool = True) -> Tuple["Figure", "Axes"]:
def plot(self, ignorePixelMask: Optional[int] = None,
figure: Optional[Figure] if TYPE_CHECKING else [] = None,
ax: Optional[Axes] if TYPE_CHECKING else [] = None,
trimToUsable: Optional[bool] = False,
title: Optional[str] = None,
show: bool = True) -> Tuple["Figure", "Axes"]:
"""Plot the object spectrum
Parameters
Expand All @@ -37,7 +42,7 @@ def plot(self, ignorePixelMask: Optional[int] = None, show: bool = True) -> Tupl
-------
figure : `matplotlib.Figure`
Figure containing the plot.
axes : `matplotlib.Axes`
ax : `matplotlib.Axes`
Axes containing the plot.
"""
import matplotlib.pyplot as plt
Expand All @@ -46,21 +51,36 @@ def plot(self, ignorePixelMask: Optional[int] = None, show: bool = True) -> Tupl
if ignorePixelMask is None:
ignorePixelMask = self.flags.get("NO_DATA")

figure, axes = plt.subplots()
if figure is None:
if ax is None:
figure, axs = plt.subplots(squeeze=False)
ax = axs[0]
else:
figure = ax.get_figure()
elif ax is None:
ax = figure.gca()

good = (self.mask & ignorePixelMask) == 0
axes.plot(self.wavelength[good], self.flux[good], 'k-', label="Flux")

for start, stop in contiguous_regions(~good):
if stop >= self.wavelength.size:
stop = self.wavelength.size - 1
axes.axvspan(self.wavelength[start], self.wavelength[stop], color="grey", alpha=0.1)

axes.set_xlabel("Wavelength (nm)")
axes.set_ylabel("Flux (nJy)")
axes.set_title(str(self.getIdentity()))
if trimToUsable:
good &= ~((self.wavelength[start] < self.wavelength) &
(self.wavelength < self.wavelength[stop]))
else:
ax.axvspan(self.wavelength[start], self.wavelength[stop], color="grey", alpha=0.1)

ax.plot(self.wavelength[good], self.flux[good], 'k-', label="Flux")

ax.set_xlabel("Wavelength (nm)")
ax.set_ylabel("Flux (nJy)")
if title is None:
title = str(self.getIdentity())
ax.set_title(title)
if show:
figure.show()
return figure, axes
return figure, ax

def resample(self, wavelength: np.ndarray) -> "PfsSimpleSpectrum":
"""Resampled the spectrum in wavelength
Expand Down Expand Up @@ -99,6 +119,10 @@ def plot(
plotSky: bool = True,
plotErrors: bool = True,
ignorePixelMask: Optional[int] = None,
figure: Optional[Figure] if TYPE_CHECKING else [] = None,
ax: Optional[Axes] if TYPE_CHECKING else [] = None,
trimToUsable: Optional[bool] = False,
title: Optional[str] = None,
show: bool = True,
) -> Tuple["Figure", "Axes"]:
"""Plot the object spectrum
Expand All @@ -123,12 +147,25 @@ def plot(
"""
if ignorePixelMask is None:
ignorePixelMask = self.flags.get("NO_DATA")
figure, axes = super().plot(ignorePixelMask=ignorePixelMask, show=False)
figure, axes = super().plot(figure=figure, ax=ax, ignorePixelMask=ignorePixelMask,
trimToUsable=trimToUsable, title=title,
show=False)
good = (self.mask & ignorePixelMask) == 0

if trimToUsable:
from matplotlib.cbook import contiguous_regions
for start, stop in contiguous_regions(~good):
if stop >= self.wavelength.size:
stop = self.wavelength.size - 1
good &= ~((self.wavelength[start] < self.wavelength) &
(self.wavelength < self.wavelength[stop]))

if plotSky:
axes.plot(self.wavelength[good], self.sky[good], 'b-', label="Sky")
axes.plot(self.wavelength[good], self.sky[good], 'b-', alpha=0.5,
label="Sky")
if plotErrors:
axes.plot(self.wavelength[good], np.sqrt(self.variance[good]), 'r-', label="Flux errors")
axes.plot(self.wavelength[good], np.sqrt(self.variance[good]), 'r-', alpha=0.5,
label="Flux errors")
if show:
figure.show()
return figure, axes
Expand Down
27 changes: 14 additions & 13 deletions python/pfs/drp/stella/datamodel/pfsFiberArraySet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __itruediv__(self, rhs):
return self.__imul__(1.0/rhs)

def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=False, show=True,
figure=None, axes=None):
figure=None, ax=None):
"""Plot the spectra
Parameters
Expand All @@ -56,14 +56,14 @@ def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=Fa
Show the plot?
figure : `matplotlib.Figure` or ``None``
The figure to use
axes : `matplotlib.Axes` or ``None``
The axes to use.
ax : `matplotlib.Axes` or ``None``
The x/y axes to use.
Returns
-------
figure : `matplotlib.Figure`
Figure containing the plot.
axes : `matplotlib.Axes`
ax : `matplotlib.Axes`
Axes containing the plot.
"""
import matplotlib.pyplot as plt
Expand All @@ -79,12 +79,13 @@ def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=Fa
xLabel = "Wavelength (nm)"

if figure is None:
if axes is None:
figure, axes = plt.subplots()
if ax is None:
figure, axs = plt.subplots(squeeze=False)
ax = axs[0]
else:
figure = axes.get_figure()
elif axes is None:
axes = figure.gca()
figure = ax.get_figure()
elif ax is None:
ax = figure.gca()

colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(fiberId)))
for ff, cc in zip(fiberId, colors):
Expand All @@ -100,14 +101,14 @@ def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=Fa
if normalized:
with np.errstate(invalid="ignore", divide="ignore"):
flux /= self.norm[index][good]
axes.plot(lam[good], flux, ls="solid", color=cc, label=str(ff))
ax.plot(lam[good], flux, ls="solid", color=cc, label=str(ff))

axes.set_xlabel(xLabel)
axes.set_ylabel(self._ylabel)
ax.set_xlabel(xLabel)
ax.set_ylabel(self._ylabel)

if show:
figure.show()
return figure, axes
return figure, ax

def resample(self, wavelength, fiberId=None):
"""Construct a new PfsFiberArraySet resampled to a common wavelength vector
Expand Down
5 changes: 1 addition & 4 deletions python/pfs/drp/stella/utils/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,6 @@ def showDetectorMap(display, pfsConfig, detMap, width=100, zoom=0, xcen=None, fi
SuNSS = TargetType.SUNSS_IMAGING in pfsConfig.targetType

showAll = False
if xcen is None and len(fiberIds) == 0:
xcen = 2000

if xcen is None:
if fiberIds is None:
fiberIds = detMap.fiberId
Expand Down Expand Up @@ -496,7 +493,7 @@ def showDetectorMap(display, pfsConfig, detMap, width=100, zoom=0, xcen=None, fi
color = 'green' if SuNSS and imagingFiber else 'red'

fiberX = detMap.getXCenter(fid, height//2)
if showAll or len(fiberIds) > 1 or np.abs(fiberX - xcen) < width/2:
if showAll or (fiberIds is not None and len(fiberIds) > 1) or np.abs(fiberX - xcen) < width/2:
fiberX = detMap.getXCenter(fid)
plt.plot(fiberX[::20], y[::20], ls=ls, alpha=alpha, label=f"{fid}",
color=color if showAll else None)
Expand Down
69 changes: 42 additions & 27 deletions python/pfs/drp/stella/utils/fiberThroughputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,70 @@ def selectWavelengthInterval(arm):
"""We'll measure the flux in [lam0, lam1]; if bkgd0 and bkgd1 are not None, measure the background
The background will be measured in [bkgd0, lam0] + [lam1, bkgd1]
"""
bkgd0, bkgd1 = None, None # the region to measure the background; may be None
intervals = []
if arm == 'b':
lamc = 557.9
(lam0, lam1), (bkgd0, bkgd1) = (lamc - 0.5, lamc + 0.5), (lamc - 1.5, lamc + 1.5)

intervals.append((lam0, lam1, bkgd0, bkgd1))
elif arm == 'r':
lam0, lam1 = 930.5, 933.1
intervals.append((930.5, 933.1, None, None))
intervals.append((947.25, 949, None, None))
elif arm == 'n':
lam0, lam1 = 1142, 1148
intervals.append((1142, 1148, None, None))
else:
raise RuntimeError(f"selectWavelengthInterval doesn't know about arm={arm}")

assert (bkgd0 is None) == (bkgd1 is None)
return intervals

return lam0, lam1, bkgd0, bkgd1

def getWavelengthLabel(arm):
"""Generate a label for the wavelength intervals used in arm"""
labels = []
for lam0, lam1, b0, b1 in selectWavelengthInterval(arm):
labels.append(f"{lam0:.1f} < $\\lambda$ < {lam1:.1f}")

def showWavelengthInterval(arm):
lam0, lam1, bkgd0, bkgd1 = selectWavelengthInterval(arm)
if len(labels) == 1:
return labels[0]
else:
return "(" + "), (".join(labels) + ")"

if bkgd0:
plt.axvspan(bkgd0, bkgd1, color='black', alpha=0.05, zorder=-2)

plt.axvspan(lam0, lam1, color='black', alpha=0.05 if bkgd0 else 0.1, zorder=-1)
def showWavelengthInterval(arm):
for lam0, lam1, bkgd0, bkgd1 in selectWavelengthInterval(arm):
if bkgd0:
plt.axvspan(bkgd0, bkgd1, color='black', alpha=0.05, zorder=-2)

plt.axvspan(lam0, lam1, color='black', alpha=0.05 if bkgd0 else 0.1, zorder=-1)


def extractFlux(y, lam, l0, l1):
return np.nansum(np.where((lam > l0) & (lam < l1), y, np.NaN), axis=1)


def measureBkgd(y, lam, arm):
lam0, lam1, bkgd0, bkgd1 = selectWavelengthInterval(arm)
bkgds = []
for lam0, lam1, bkgd0, bkgd1 in selectWavelengthInterval(arm):
if bkgd0:
bkgd = extractFlux(y, lam, bkgd0, lam0) + extractFlux(y, lam, lam1, bkgd1)
bkgd /= (bkgd1 - bkgd0 - (lam1 - lam0))
else:
bkgd = 0*extractFlux(y, lam, lam0, lam1)

if bkgd0:
bkgd = extractFlux(y, lam, bkgd0, lam0) + extractFlux(y, lam, lam1, bkgd1)
bkgd /= (bkgd1 - bkgd0 - (lam1 - lam0))
else:
return 0*extractFlux(y, lam, lam0, lam1)
bkgds.append(bkgd)

return bkgd
return bkgds


def measureFlux(y, lam, arm):
lam0, lam1, bkgd0, bkgd1 = selectWavelengthInterval(arm)
fluxes = []
for (lam0, lam1, bkgd0, bkgd1), bkgd in zip(selectWavelengthInterval(arm), measureBkgd(y, lam, arm)):
flux = extractFlux(y, lam, lam0, lam1)
flux -= (lam1 - lam0)*bkgd

flux = extractFlux(y, lam, lam0, lam1)
flux -= (lam1 - lam0)*measureBkgd(y, lam, arm)
fluxes.append(flux)

return flux
return np.sum(fluxes, axis=0)


def estimateFiberThroughputs(butler, visits, arms="brn", what="flux",
Expand Down Expand Up @@ -175,7 +191,7 @@ def estimateFiberThroughputs(butler, visits, arms="brn", what="flux",

y = list(measureFlux(y, spec.wavelength, dataId["arm"]))

if True:
if False:
#
# Deal with fibres missing in spec; PIPE2D-1401
# We insert NaNs into the measurements in the proper places
Expand All @@ -185,7 +201,7 @@ def estimateFiberThroughputs(butler, visits, arms="brn", what="flux",
j = np.where(config.fiberId == fid)[0][0]
y[j:j] = [np.NaN]

c.append(y)
c.append(y)

visitC[what][dataId["arm"]][dataId["visit"]] = np.array(sum(c, []))
visitConfig[what][dataId["arm"]][dataId["visit"]] = pfsConfig[ll]
Expand Down Expand Up @@ -259,9 +275,9 @@ def plotThroughputs(cache, visits, arms, what="flux", refVisit=-1, showHome=Fals
raise RuntimeError(f"Unable to read value of {what} from cache for {dataId}")

title = [f"{dataId['visit']} {dataId['arm']}"]
lam0, lam1 = selectWavelengthInterval(dataId["arm"])[:2]
lam0, lam1 = selectWavelengthInterval(dataId["arm"])[0][:2] # XXX just the first pair
colorbarlabel = [f"{'relative' if refVisit > 0 else ''}flux in "
f"{lam0:.1f} < $\\lambda$ < {lam1:.1f}"]
+ getWavelengthLabel(arm)]
title.append(f"[{pfsConfig.designName}]")
title.append("\n")
title.append(dict(flux="",
Expand Down Expand Up @@ -443,8 +459,7 @@ def throughputPerSpectrograph(cache, visit, arm, what="flux", title=""):
II = plt.imshow(im, origin="lower", aspect="auto", interpolation='none', vmin=vmin, vmax=vmax,
extent=(0.5, im.shape[1] + 0.5, 0.5, 4.5))

lam0, lam1 = selectWavelengthInterval(arm)[:2]
plt.colorbar(II, label=f"flux in {lam0:.1f} < $\\lambda$ < {lam1:.1f}")
plt.colorbar(II, label=f"flux in {getWavelengthLabel(arm)}")

plt.xlabel("Fibre Hole")
plt.ylabel("spectrograph")
Expand Down

0 comments on commit 5637f14

Please sign in to comment.