diff --git a/pahfit/base.py b/pahfit/base.py index a0ff47e..0a55a88 100644 --- a/pahfit/base.py +++ b/pahfit/base.py @@ -1,5 +1,4 @@ import numpy as np -import matplotlib as mpl from pahfit.instrument import within_segment, fwhm from pahfit.errors import PAHFITModelError @@ -265,196 +264,6 @@ def model_from_param_info(param_info): return model - @staticmethod - def plot(axs, x, y, yerr, model, model_samples=1000, scalefac_resid=2): - """ - Plot model using axis object. - - Parameters - ---------- - axs : matplotlib.axis objects - where to put the plot - x : floats - wavelength points - y : floats - observed spectrum - yerr: floats - observed spectrum uncertainties - model : PAHFITBase model (astropy modeling CompoundModel) - model giving all the components and parameters - model_samples : int - Total number of wavelength points to allocate to the model display - scalefac_resid : float - Factor multiplying the standard deviation of the residuals to adjust plot limits - """ - # remove units if they are present - if hasattr(x, "value"): - x = x.value - if hasattr(y, "value"): - y = y.value - if hasattr(yerr, "value"): - yerr = yerr.value - - # Fine x samples for model fit - x_mod = np.logspace(np.log10(min(x)), np.log10(max(x)), model_samples) - - # spectrum and best fit model - ax = axs[0] - ax.set_yscale("linear") - ax.set_xscale("log") - ax.minorticks_on() - ax.tick_params(axis="both", - which="major", - top="on", - right="on", - direction="in", - length=10) - ax.tick_params(axis="both", - which="minor", - top="on", - right="on", - direction="in", - length=5) - - ax_att = ax.twinx() # axis for plotting the extinction curve - ax_att.tick_params(which="minor", direction="in", length=5) - ax_att.tick_params(which="major", direction="in", length=10) - ax_att.minorticks_on() - - # get the extinction model (probably a better way to do this) - ext_model = None - for cmodel in model: - if isinstance(cmodel, S07_attenuation): - ext_model = cmodel(x_mod) - - # get additional extinction components that can be - # characterized by functional forms (Drude profile in this case) - for cmodel in model: - if isinstance(cmodel, att_Drude1D): - if ext_model is not None: - ext_model *= cmodel(x_mod) - else: - ext_model = cmodel(x_mod) - ax_att.plot(x_mod, ext_model, "k--", alpha=0.5) - ax_att.set_ylabel("Attenuation") - ax_att.set_ylim(0, 1.1) - - # Define legend lines - Leg_lines = [ - mpl.lines.Line2D([0], [0], color="k", linestyle="--", lw=2), - mpl.lines.Line2D([0], [0], color="#FE6100", lw=2), - mpl.lines.Line2D([0], [0], color="#648FFF", lw=2, alpha=0.5), - mpl.lines.Line2D([0], [0], color="#DC267F", lw=2, alpha=0.5), - mpl.lines.Line2D([0], [0], color="#785EF0", lw=2, alpha=1), - mpl.lines.Line2D([0], [0], color="#FFB000", lw=2, alpha=0.5), - ] - - # create the continum compound model (base for plotting lines) - cont_components = [] - - for cmodel in model: - if isinstance(cmodel, BlackBody1D): - cont_components.append(cmodel) - # plot as we go - ax.plot(x_mod, - cmodel(x_mod) * ext_model / x_mod, - "#FFB000", - alpha=0.5) - cont_model = cont_components[0] - for cmodel in cont_components[1:]: - cont_model += cmodel - cont_y = cont_model(x_mod) - - # now plot the dust bands and lines - for cmodel in model: - if isinstance(cmodel, Gaussian1D): - ax.plot( - x_mod, - (cont_y + cmodel(x_mod)) * ext_model / x_mod, - "#DC267F", - alpha=0.5, - ) - if isinstance(cmodel, Drude1D): - ax.plot( - x_mod, - (cont_y + cmodel(x_mod)) * ext_model / x_mod, - "#648FFF", - alpha=0.5, - ) - - ax.plot(x_mod, cont_y * ext_model / x_mod, "#785EF0", alpha=1) - - ax.plot(x_mod, model(x_mod) / x_mod, "#FE6100", alpha=1) - ax.errorbar( - x, - y / x, - yerr=yerr / x, - fmt="o", - markeredgecolor="k", - markerfacecolor="none", - ecolor="k", - elinewidth=0.2, - capsize=0.5, - markersize=6, - ) - - ax.set_ylim(0) - ax.set_ylabel(r"$\nu F_{\nu}$") - - ax.legend( - Leg_lines, - [ - "S07_attenuation", - "Spectrum Fit", - "Dust Features", - r"Atomic and $H_2$ Lines", - "Total Continuum Emissions", - "Continuum Components", - ], - prop={"size": 10}, - loc="best", - facecolor="white", - framealpha=1, - ncol=3, - ) - - # residuals, lower sub-figure - res = (y - model(x)) / x - std = np.std(res) - ax = axs[1] - - ax.set_yscale("linear") - ax.set_xscale("log") - ax.tick_params(axis="both", - which="major", - top="on", - right="on", - direction="in", - length=10) - ax.tick_params(axis="both", - which="minor", - top="on", - right="on", - direction="in", - length=5) - ax.minorticks_on() - - # Custom X axis ticks - ax.xaxis.set_ticks( - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 25, 30, 40]) - - ax.axhline(0, linestyle="--", color="gray", zorder=0) - ax.plot(x, res, "ko-", fillstyle="none", zorder=1) - ax.set_ylim(-scalefac_resid * std, scalefac_resid * std) - ax.set_xlim(np.floor(np.amin(x)), np.ceil(np.amax(x))) - ax.set_xlabel(r"$\lambda$ [$\mu m$]") - ax.set_ylabel("Residuals [%]") - - # scalar x-axis marks - ax.xaxis.set_minor_formatter(mpl.ticker.ScalarFormatter()) - ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter()) - - @staticmethod def update_dictionary(feature_dict, instrumentname, update_fwhms=False, redshift=0): """ Update parameter dictionary based on the instrument name. diff --git a/pahfit/model.py b/pahfit/model.py index 69d22f3..1dd15a5 100644 --- a/pahfit/model.py +++ b/pahfit/model.py @@ -2,17 +2,18 @@ from astropy import units as u import copy from astropy.modeling.fitting import LevMarLSQFitter +import matplotlib as mpl from matplotlib import pyplot as plt import numpy as np from scipy import interpolate, integrate +from pahfit import units from pahfit.features.util import bounded_is_fixed from pahfit.features import Features from pahfit.base import PAHFITBase from pahfit import instrument from pahfit.errors import PAHFITModelError -from pahfit.component_models import BlackBody1D -from pahfit import units +from pahfit.component_models import BlackBody1D, S07_attenuation class Model: @@ -411,41 +412,239 @@ def info(self): """Print out the last fit results.""" print(self.astropy_result) - def plot(self, spec=None, redshift=None): + def plot( + self, + spec=None, + redshift=None, + use_instrument_fwhm=False, + label_lines=False, + scalefac_resid=2, + **errorbar_kwargs, + ): """Plot model, and optionally compare to observational data. Parameters ---------- spec : Spectrum1D - Observational data. Does not have to be the same data that - was used for guessing or fitting. + Observational data. The units should be compatible with the + data that were used for the fit, but it does not have to be + the exact same spectrum. The spectrum will be converted to + internal units before plotting. redshift : float Redshift used to shift from the physical model, to the - observed model. + observed model. If None, it will be taken from spec.redshift - If None, will be taken from spec.redshift + use_instrument_fwhm : bool + For the lines, the default is to use the fwhm values + contained in the Features table. When set to True, the fwhm + will be determined by the instrument model instead. + + label_lines : bool + Add labels with the names of the lines, at the position of + each line. + + scalefac_resid : float + Factor multiplying the standard deviation of the residuals + to adjust plot limits. + + errorbar_kwargs : dict + Customize the data points plot by passing the given keyword + arguments to matplotlib.pyplot.errorbar. """ inst, z = self._parse_instrument_and_redshift(spec, redshift) _, _, _, xz, yz, uncz = self._convert_spec_data(spec, z) - # Always use the current FWHM here (use_instrument_fwhm would - # overwrite the fitted value in the instrument overlap regions!) - astropy_model = self._construct_astropy_model( - inst, z, use_instrument_fwhm=False - ) - - enough_samples = max(1000, len(spec.wavelength)) + enough_samples = max(10000, len(spec.wavelength)) + x_mod = np.logspace(np.log10(min(xz)), np.log10(max(xz)), enough_samples) fig, axs = plt.subplots( ncols=1, nrows=2, - figsize=(15, 10), + figsize=(10, 10), gridspec_kw={"height_ratios": [3, 1]}, sharex=True, ) - PAHFITBase.plot(axs, xz, yz, uncz, astropy_model, model_samples=enough_samples) + + # spectrum and best fit model + ax = axs[0] + ax.set_yscale("linear") + ax.set_xscale("log") + ax.minorticks_on() + ax.tick_params( + axis="both", which="major", top="on", right="on", direction="in", length=10 + ) + ax.tick_params( + axis="both", which="minor", top="on", right="on", direction="in", length=5 + ) + + ext_model = None + has_att = "attenuation" in self.features["kind"] + has_abs = "absorption" in self.features["kind"] + if has_att: + row = self.features[self.features["kind"] == "attenuation"][0] + tau = row["tau"][0] + ext_model = S07_attenuation(tau_sil=tau)(x_mod) + + if has_abs: + raise NotImplementedError( + "plotting absorption features not implemented yet" + ) + + if has_att or has_abs: + ax_att = ax.twinx() # axis for plotting the extinction curve + ax_att.tick_params(which="minor", direction="in", length=5) + ax_att.tick_params(which="major", direction="in", length=10) + ax_att.minorticks_on() + ax_att.plot(x_mod, ext_model, "k--", alpha=0.5) + ax_att.set_ylabel("Attenuation") + ax_att.set_ylim(0, 1.1) + else: + ext_model = np.ones(len(x_mod)) + + # Define legend lines + Leg_lines = [ + mpl.lines.Line2D([0], [0], color="k", linestyle="--", lw=2), + mpl.lines.Line2D([0], [0], color="#FE6100", lw=2), + mpl.lines.Line2D([0], [0], color="#648FFF", lw=2, alpha=0.5), + mpl.lines.Line2D([0], [0], color="#DC267F", lw=2, alpha=0.5), + mpl.lines.Line2D([0], [0], color="#785EF0", lw=2, alpha=1), + mpl.lines.Line2D([0], [0], color="#FFB000", lw=2, alpha=0.5), + ] + + # local utility + def tabulate_components(kind): + ss = {} + for name in self.features[self.features["kind"] == kind]["name"]: + ss[name] = self.tabulate(inst, z, x_mod, self.features["name"] == name) + return {name: s.flux.value for name, s in ss.items()} + + cont_y = np.zeros(len(x_mod)) + if "dust_continuum" in self.features["kind"]: + # one plot for every component + for y in tabulate_components("dust_continuum").values(): + ax.plot(x_mod, y * ext_model, "#FFB000", alpha=0.5) + # keep track of total continuum + cont_y += y + + if "starlight" in self.features["kind"]: + star_y = self.tabulate( + inst, z, x_mod, self.features["kind"] == "starlight" + ).flux.value + ax.plot(x_mod, star_y * ext_model, "#ffB000", alpha=0.5) + cont_y += star_y + + # total continuum + ax.plot(x_mod, cont_y * ext_model, "#785EF0", alpha=1) + + # now plot the dust bands and lines + if "dust_feature" in self.features["kind"]: + for y in tabulate_components("dust_feature").values(): + ax.plot( + x_mod, + (cont_y + y) * ext_model, + "#648FFF", + alpha=0.5, + ) + + if "line" in self.features["kind"]: + for name, y in tabulate_components("line").items(): + ax.plot( + x_mod, + (cont_y + y) * ext_model, + "#DC267F", + alpha=0.5, + ) + if label_lines: + i = np.argmax(y) + # ignore out of range lines + if i > 0 and i < len(y) - 1: + w = x_mod[i] + ax.text( + w, + y[i], + name, + va="center", + ha="center", + rotation="vertical", + bbox=dict(facecolor="white", alpha=0.75, pad=0), + ) + + ax.plot(x_mod, self.tabulate(inst, z, x_mod).flux.value, "#FE6100", alpha=1) + + # data + default_kwargs = dict( + fmt="o", + markeredgecolor="k", + markerfacecolor="none", + ecolor="k", + elinewidth=0.2, + capsize=0.5, + markersize=6, + ) + + ax.errorbar(xz, yz, yerr=uncz, **(default_kwargs | errorbar_kwargs)) + + ax.set_ylim(0) + ax.set_ylabel(r"$\nu F_{\nu}$") + + ax.legend( + Leg_lines, + [ + "S07_attenuation", + "Spectrum Fit", + "Dust Features", + r"Atomic and $H_2$ Lines", + "Total Continuum Emissions", + "Continuum Components", + ], + prop={"size": 10}, + loc="best", + facecolor="white", + framealpha=1, + ncol=3, + ) + + # residuals = data in rest frame - (model evaluated at rest frame wavelengths) + res = yz - self.tabulate(inst, 0, xz).flux.value + std = np.nanstd(res) + ax = axs[1] + + ax.set_yscale("linear") + ax.set_xscale("log") + ax.tick_params( + axis="both", which="major", top="on", right="on", direction="in", length=10 + ) + ax.tick_params( + axis="both", which="minor", top="on", right="on", direction="in", length=5 + ) + ax.minorticks_on() + + # Custom X axis ticks + ax.xaxis.set_ticks( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 25, 30, 40] + ) + + ax.axhline(0, linestyle="--", color="gray", zorder=0) + ax.plot( + xz, + res, + "ko", + fillstyle="none", + zorder=1, + markersize=errorbar_kwargs.get("markersize", None), + alpha=errorbar_kwargs.get("alpha", None), + linestyle="none", + ) + ax.set_ylim(-scalefac_resid * std, scalefac_resid * std) + ax.set_xlim(np.floor(np.amin(xz)), np.ceil(np.amax(xz))) + ax.set_xlabel(r"$\lambda$ [$\mu m$]") + ax.set_ylabel("Residuals [%]") + + # scalar x-axis marks + ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter()) fig.subplots_adjust(hspace=0) + return fig def copy(self): """Copy the model. diff --git a/pahfit/scripts/plot_pahfit.py b/pahfit/scripts/plot_pahfit.py index d5bb853..a876505 100755 --- a/pahfit/scripts/plot_pahfit.py +++ b/pahfit/scripts/plot_pahfit.py @@ -6,11 +6,8 @@ import matplotlib as mpl from pahfit.model import Model -from pahfit.base import PAHFITBase from pahfit.helpers import read_spectrum -from astropy import units as u - def initialize_parser(): """ @@ -78,14 +75,6 @@ def main(): def default_layout_plot(spec, model, scalefac_resid): - """ - Returns - ------- - fig : Figure object - - """ - - # plot result fontsize = 18 font = {"size": fontsize} mpl.rc("font", **font) @@ -96,29 +85,7 @@ def default_layout_plot(spec, model, scalefac_resid): mpl.rc("xtick.minor", size=3, width=1) mpl.rc("ytick.minor", size=3, width=1) - fig, axs = plt.subplots( - ncols=1, - nrows=2, - figsize=(15, 10), - gridspec_kw={"height_ratios": [3, 1]}, - sharex=True, - ) - - enough_samples = max(1000, len(spec.wavelength)) - - PAHFITBase.plot( - axs, - spec.wavelength.to(u.micron).value, - spec.flux, - spec.uncertainty.array, - model._construct_astropy_model( - spec.meta["instrument"], spec.redshift, use_instrument_fwhm=False - ), - model_samples=enough_samples, - scalefac_resid=scalefac_resid, - ) - - # use the whitespace better + fig = model.plot(spec, scalefac_resid=scalefac_resid) fig.subplots_adjust(hspace=0) return fig