diff --git a/python/pfs/drp/stella/datamodel/pfsFiberArray.py b/python/pfs/drp/stella/datamodel/pfsFiberArray.py index 3d941cfa4..cc916b72f 100644 --- a/python/pfs/drp/stella/datamodel/pfsFiberArray.py +++ b/python/pfs/drp/stella/datamodel/pfsFiberArray.py @@ -23,7 +23,12 @@ def __itruediv__(self, rhs: ArrayLike) -> "PfsSimpleSpectrum": self.flux /= rhs return self - def plot(self, ignorePixelMask: Optional[int] = None, show: bool = True) -> Tuple["Figure", "Axes"]: + def plot(self, ignorePixelMask: Optional[int] = None, + figure: Optional[Figure] if TYPE_CHECKING else [] = None, + ax: Optional[Axes] if TYPE_CHECKING else [] = None, + trimToUsable: Optional[bool] = False, + title: Optional[str] = None, + show: bool = True) -> Tuple["Figure", "Axes"]: """Plot the object spectrum Parameters @@ -37,7 +42,7 @@ def plot(self, ignorePixelMask: Optional[int] = None, show: bool = True) -> Tupl ------- figure : `matplotlib.Figure` Figure containing the plot. - axes : `matplotlib.Axes` + ax : `matplotlib.Axes` Axes containing the plot. """ import matplotlib.pyplot as plt @@ -46,21 +51,36 @@ def plot(self, ignorePixelMask: Optional[int] = None, show: bool = True) -> Tupl if ignorePixelMask is None: ignorePixelMask = self.flags.get("NO_DATA") - figure, axes = plt.subplots() + if figure is None: + if ax is None: + figure, axs = plt.subplots(squeeze=False) + ax = axs[0] + else: + figure = ax.get_figure() + elif ax is None: + ax = figure.gca() + good = (self.mask & ignorePixelMask) == 0 - axes.plot(self.wavelength[good], self.flux[good], 'k-', label="Flux") for start, stop in contiguous_regions(~good): if stop >= self.wavelength.size: stop = self.wavelength.size - 1 - axes.axvspan(self.wavelength[start], self.wavelength[stop], color="grey", alpha=0.1) - - axes.set_xlabel("Wavelength (nm)") - axes.set_ylabel("Flux (nJy)") - axes.set_title(str(self.getIdentity())) + if trimToUsable: + good &= ~((self.wavelength[start] < self.wavelength) & + (self.wavelength < self.wavelength[stop])) + else: + ax.axvspan(self.wavelength[start], self.wavelength[stop], color="grey", alpha=0.1) + + ax.plot(self.wavelength[good], self.flux[good], 'k-', label="Flux") + + ax.set_xlabel("Wavelength (nm)") + ax.set_ylabel("Flux (nJy)") + if title is None: + title = str(self.getIdentity()) + ax.set_title(title) if show: figure.show() - return figure, axes + return figure, ax def resample(self, wavelength: np.ndarray) -> "PfsSimpleSpectrum": """Resampled the spectrum in wavelength @@ -99,6 +119,10 @@ def plot( plotSky: bool = True, plotErrors: bool = True, ignorePixelMask: Optional[int] = None, + figure: Optional[Figure] if TYPE_CHECKING else [] = None, + ax: Optional[Axes] if TYPE_CHECKING else [] = None, + trimToUsable: Optional[bool] = False, + title: Optional[str] = None, show: bool = True, ) -> Tuple["Figure", "Axes"]: """Plot the object spectrum @@ -123,12 +147,25 @@ def plot( """ if ignorePixelMask is None: ignorePixelMask = self.flags.get("NO_DATA") - figure, axes = super().plot(ignorePixelMask=ignorePixelMask, show=False) + figure, axes = super().plot(figure=figure, ax=ax, ignorePixelMask=ignorePixelMask, + trimToUsable=trimToUsable, title=title, + show=False) good = (self.mask & ignorePixelMask) == 0 + + if trimToUsable: + from matplotlib.cbook import contiguous_regions + for start, stop in contiguous_regions(~good): + if stop >= self.wavelength.size: + stop = self.wavelength.size - 1 + good &= ~((self.wavelength[start] < self.wavelength) & + (self.wavelength < self.wavelength[stop])) + if plotSky: - axes.plot(self.wavelength[good], self.sky[good], 'b-', label="Sky") + axes.plot(self.wavelength[good], self.sky[good], 'b-', alpha=0.5, + label="Sky") if plotErrors: - axes.plot(self.wavelength[good], np.sqrt(self.variance[good]), 'r-', label="Flux errors") + axes.plot(self.wavelength[good], np.sqrt(self.variance[good]), 'r-', alpha=0.5, + label="Flux errors") if show: figure.show() return figure, axes diff --git a/python/pfs/drp/stella/datamodel/pfsFiberArraySet.py b/python/pfs/drp/stella/datamodel/pfsFiberArraySet.py index 52c934734..620f8e926 100644 --- a/python/pfs/drp/stella/datamodel/pfsFiberArraySet.py +++ b/python/pfs/drp/stella/datamodel/pfsFiberArraySet.py @@ -39,7 +39,7 @@ def __itruediv__(self, rhs): return self.__imul__(1.0/rhs) def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=False, show=True, - figure=None, axes=None): + figure=None, ax=None): """Plot the spectra Parameters @@ -56,14 +56,14 @@ def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=Fa Show the plot? figure : `matplotlib.Figure` or ``None`` The figure to use - axes : `matplotlib.Axes` or ``None`` - The axes to use. + ax : `matplotlib.Axes` or ``None`` + The x/y axes to use. Returns ------- figure : `matplotlib.Figure` Figure containing the plot. - axes : `matplotlib.Axes` + ax : `matplotlib.Axes` Axes containing the plot. """ import matplotlib.pyplot as plt @@ -79,12 +79,13 @@ def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=Fa xLabel = "Wavelength (nm)" if figure is None: - if axes is None: - figure, axes = plt.subplots() + if ax is None: + figure, axs = plt.subplots(squeeze=False) + ax = axs[0] else: - figure = axes.get_figure() - elif axes is None: - axes = figure.gca() + figure = ax.get_figure() + elif ax is None: + ax = figure.gca() colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(fiberId))) for ff, cc in zip(fiberId, colors): @@ -100,14 +101,14 @@ def plot(self, fiberId=None, usePixels=False, ignorePixelMask=0x0, normalized=Fa if normalized: with np.errstate(invalid="ignore", divide="ignore"): flux /= self.norm[index][good] - axes.plot(lam[good], flux, ls="solid", color=cc, label=str(ff)) + ax.plot(lam[good], flux, ls="solid", color=cc, label=str(ff)) - axes.set_xlabel(xLabel) - axes.set_ylabel(self._ylabel) + ax.set_xlabel(xLabel) + ax.set_ylabel(self._ylabel) if show: figure.show() - return figure, axes + return figure, ax def resample(self, wavelength, fiberId=None): """Construct a new PfsFiberArraySet resampled to a common wavelength vector diff --git a/python/pfs/drp/stella/utils/display.py b/python/pfs/drp/stella/utils/display.py index 27bed7e20..06cd99bd5 100644 --- a/python/pfs/drp/stella/utils/display.py +++ b/python/pfs/drp/stella/utils/display.py @@ -437,9 +437,6 @@ def showDetectorMap(display, pfsConfig, detMap, width=100, zoom=0, xcen=None, fi SuNSS = TargetType.SUNSS_IMAGING in pfsConfig.targetType showAll = False - if xcen is None and len(fiberIds) == 0: - xcen = 2000 - if xcen is None: if fiberIds is None: fiberIds = detMap.fiberId @@ -496,7 +493,7 @@ def showDetectorMap(display, pfsConfig, detMap, width=100, zoom=0, xcen=None, fi color = 'green' if SuNSS and imagingFiber else 'red' fiberX = detMap.getXCenter(fid, height//2) - if showAll or len(fiberIds) > 1 or np.abs(fiberX - xcen) < width/2: + if showAll or (fiberIds is not None and len(fiberIds) > 1) or np.abs(fiberX - xcen) < width/2: fiberX = detMap.getXCenter(fid) plt.plot(fiberX[::20], y[::20], ls=ls, alpha=alpha, label=f"{fid}", color=color if showAll else None) diff --git a/python/pfs/drp/stella/utils/fiberThroughputs.py b/python/pfs/drp/stella/utils/fiberThroughputs.py index 73ab43e3e..055bef061 100644 --- a/python/pfs/drp/stella/utils/fiberThroughputs.py +++ b/python/pfs/drp/stella/utils/fiberThroughputs.py @@ -9,29 +9,41 @@ def selectWavelengthInterval(arm): """We'll measure the flux in [lam0, lam1]; if bkgd0 and bkgd1 are not None, measure the background The background will be measured in [bkgd0, lam0] + [lam1, bkgd1] """ - bkgd0, bkgd1 = None, None # the region to measure the background; may be None + intervals = [] if arm == 'b': lamc = 557.9 (lam0, lam1), (bkgd0, bkgd1) = (lamc - 0.5, lamc + 0.5), (lamc - 1.5, lamc + 1.5) + + intervals.append((lam0, lam1, bkgd0, bkgd1)) elif arm == 'r': - lam0, lam1 = 930.5, 933.1 + intervals.append((930.5, 933.1, None, None)) + intervals.append((947.25, 949, None, None)) elif arm == 'n': - lam0, lam1 = 1142, 1148 + intervals.append((1142, 1148, None, None)) else: raise RuntimeError(f"selectWavelengthInterval doesn't know about arm={arm}") - assert (bkgd0 is None) == (bkgd1 is None) + return intervals - return lam0, lam1, bkgd0, bkgd1 +def getWavelengthLabel(arm): + """Generate a label for the wavelength intervals used in arm""" + labels = [] + for lam0, lam1, b0, b1 in selectWavelengthInterval(arm): + labels.append(f"{lam0:.1f} < $\\lambda$ < {lam1:.1f}") -def showWavelengthInterval(arm): - lam0, lam1, bkgd0, bkgd1 = selectWavelengthInterval(arm) + if len(labels) == 1: + return labels[0] + else: + return "(" + "), (".join(labels) + ")" - if bkgd0: - plt.axvspan(bkgd0, bkgd1, color='black', alpha=0.05, zorder=-2) - plt.axvspan(lam0, lam1, color='black', alpha=0.05 if bkgd0 else 0.1, zorder=-1) +def showWavelengthInterval(arm): + for lam0, lam1, bkgd0, bkgd1 in selectWavelengthInterval(arm): + if bkgd0: + plt.axvspan(bkgd0, bkgd1, color='black', alpha=0.05, zorder=-2) + + plt.axvspan(lam0, lam1, color='black', alpha=0.05 if bkgd0 else 0.1, zorder=-1) def extractFlux(y, lam, l0, l1): @@ -39,24 +51,28 @@ def extractFlux(y, lam, l0, l1): def measureBkgd(y, lam, arm): - lam0, lam1, bkgd0, bkgd1 = selectWavelengthInterval(arm) + bkgds = [] + for lam0, lam1, bkgd0, bkgd1 in selectWavelengthInterval(arm): + if bkgd0: + bkgd = extractFlux(y, lam, bkgd0, lam0) + extractFlux(y, lam, lam1, bkgd1) + bkgd /= (bkgd1 - bkgd0 - (lam1 - lam0)) + else: + bkgd = 0*extractFlux(y, lam, lam0, lam1) - if bkgd0: - bkgd = extractFlux(y, lam, bkgd0, lam0) + extractFlux(y, lam, lam1, bkgd1) - bkgd /= (bkgd1 - bkgd0 - (lam1 - lam0)) - else: - return 0*extractFlux(y, lam, lam0, lam1) + bkgds.append(bkgd) - return bkgd + return bkgds def measureFlux(y, lam, arm): - lam0, lam1, bkgd0, bkgd1 = selectWavelengthInterval(arm) + fluxes = [] + for (lam0, lam1, bkgd0, bkgd1), bkgd in zip(selectWavelengthInterval(arm), measureBkgd(y, lam, arm)): + flux = extractFlux(y, lam, lam0, lam1) + flux -= (lam1 - lam0)*bkgd - flux = extractFlux(y, lam, lam0, lam1) - flux -= (lam1 - lam0)*measureBkgd(y, lam, arm) + fluxes.append(flux) - return flux + return np.sum(fluxes, axis=0) def estimateFiberThroughputs(butler, visits, arms="brn", what="flux", @@ -175,7 +191,7 @@ def estimateFiberThroughputs(butler, visits, arms="brn", what="flux", y = list(measureFlux(y, spec.wavelength, dataId["arm"])) - if True: + if False: # # Deal with fibres missing in spec; PIPE2D-1401 # We insert NaNs into the measurements in the proper places @@ -185,7 +201,7 @@ def estimateFiberThroughputs(butler, visits, arms="brn", what="flux", j = np.where(config.fiberId == fid)[0][0] y[j:j] = [np.NaN] - c.append(y) + c.append(y) visitC[what][dataId["arm"]][dataId["visit"]] = np.array(sum(c, [])) visitConfig[what][dataId["arm"]][dataId["visit"]] = pfsConfig[ll] @@ -259,9 +275,9 @@ def plotThroughputs(cache, visits, arms, what="flux", refVisit=-1, showHome=Fals raise RuntimeError(f"Unable to read value of {what} from cache for {dataId}") title = [f"{dataId['visit']} {dataId['arm']}"] - lam0, lam1 = selectWavelengthInterval(dataId["arm"])[:2] + lam0, lam1 = selectWavelengthInterval(dataId["arm"])[0][:2] # XXX just the first pair colorbarlabel = [f"{'relative' if refVisit > 0 else ''}flux in " - f"{lam0:.1f} < $\\lambda$ < {lam1:.1f}"] + + getWavelengthLabel(arm)] title.append(f"[{pfsConfig.designName}]") title.append("\n") title.append(dict(flux="", @@ -443,8 +459,7 @@ def throughputPerSpectrograph(cache, visit, arm, what="flux", title=""): II = plt.imshow(im, origin="lower", aspect="auto", interpolation='none', vmin=vmin, vmax=vmax, extent=(0.5, im.shape[1] + 0.5, 0.5, 4.5)) - lam0, lam1 = selectWavelengthInterval(arm)[:2] - plt.colorbar(II, label=f"flux in {lam0:.1f} < $\\lambda$ < {lam1:.1f}") + plt.colorbar(II, label=f"flux in {getWavelengthLabel(arm)}") plt.xlabel("Fibre Hole") plt.ylabel("spectrograph")