Skip to content

Commit

Permalink
Merge pull request #281 from drvdputt/plot_rewrite
Browse files Browse the repository at this point in the history
Rewrite of Model.plot()
  • Loading branch information
jdtsmith authored May 13, 2024
2 parents 6ec9596 + 50f232c commit 68ce437
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 241 deletions.
191 changes: 0 additions & 191 deletions pahfit/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import matplotlib as mpl

from pahfit.instrument import within_segment, fwhm
from pahfit.errors import PAHFITModelError
Expand Down Expand Up @@ -265,196 +264,6 @@ def model_from_param_info(param_info):

return model

@staticmethod
def plot(axs, x, y, yerr, model, model_samples=1000, scalefac_resid=2):
"""
Plot model using axis object.
Parameters
----------
axs : matplotlib.axis objects
where to put the plot
x : floats
wavelength points
y : floats
observed spectrum
yerr: floats
observed spectrum uncertainties
model : PAHFITBase model (astropy modeling CompoundModel)
model giving all the components and parameters
model_samples : int
Total number of wavelength points to allocate to the model display
scalefac_resid : float
Factor multiplying the standard deviation of the residuals to adjust plot limits
"""
# remove units if they are present
if hasattr(x, "value"):
x = x.value
if hasattr(y, "value"):
y = y.value
if hasattr(yerr, "value"):
yerr = yerr.value

# Fine x samples for model fit
x_mod = np.logspace(np.log10(min(x)), np.log10(max(x)), model_samples)

# spectrum and best fit model
ax = axs[0]
ax.set_yscale("linear")
ax.set_xscale("log")
ax.minorticks_on()
ax.tick_params(axis="both",
which="major",
top="on",
right="on",
direction="in",
length=10)
ax.tick_params(axis="both",
which="minor",
top="on",
right="on",
direction="in",
length=5)

ax_att = ax.twinx() # axis for plotting the extinction curve
ax_att.tick_params(which="minor", direction="in", length=5)
ax_att.tick_params(which="major", direction="in", length=10)
ax_att.minorticks_on()

# get the extinction model (probably a better way to do this)
ext_model = None
for cmodel in model:
if isinstance(cmodel, S07_attenuation):
ext_model = cmodel(x_mod)

# get additional extinction components that can be
# characterized by functional forms (Drude profile in this case)
for cmodel in model:
if isinstance(cmodel, att_Drude1D):
if ext_model is not None:
ext_model *= cmodel(x_mod)
else:
ext_model = cmodel(x_mod)
ax_att.plot(x_mod, ext_model, "k--", alpha=0.5)
ax_att.set_ylabel("Attenuation")
ax_att.set_ylim(0, 1.1)

# Define legend lines
Leg_lines = [
mpl.lines.Line2D([0], [0], color="k", linestyle="--", lw=2),
mpl.lines.Line2D([0], [0], color="#FE6100", lw=2),
mpl.lines.Line2D([0], [0], color="#648FFF", lw=2, alpha=0.5),
mpl.lines.Line2D([0], [0], color="#DC267F", lw=2, alpha=0.5),
mpl.lines.Line2D([0], [0], color="#785EF0", lw=2, alpha=1),
mpl.lines.Line2D([0], [0], color="#FFB000", lw=2, alpha=0.5),
]

# create the continum compound model (base for plotting lines)
cont_components = []

for cmodel in model:
if isinstance(cmodel, BlackBody1D):
cont_components.append(cmodel)
# plot as we go
ax.plot(x_mod,
cmodel(x_mod) * ext_model / x_mod,
"#FFB000",
alpha=0.5)
cont_model = cont_components[0]
for cmodel in cont_components[1:]:
cont_model += cmodel
cont_y = cont_model(x_mod)

# now plot the dust bands and lines
for cmodel in model:
if isinstance(cmodel, Gaussian1D):
ax.plot(
x_mod,
(cont_y + cmodel(x_mod)) * ext_model / x_mod,
"#DC267F",
alpha=0.5,
)
if isinstance(cmodel, Drude1D):
ax.plot(
x_mod,
(cont_y + cmodel(x_mod)) * ext_model / x_mod,
"#648FFF",
alpha=0.5,
)

ax.plot(x_mod, cont_y * ext_model / x_mod, "#785EF0", alpha=1)

ax.plot(x_mod, model(x_mod) / x_mod, "#FE6100", alpha=1)
ax.errorbar(
x,
y / x,
yerr=yerr / x,
fmt="o",
markeredgecolor="k",
markerfacecolor="none",
ecolor="k",
elinewidth=0.2,
capsize=0.5,
markersize=6,
)

ax.set_ylim(0)
ax.set_ylabel(r"$\nu F_{\nu}$")

ax.legend(
Leg_lines,
[
"S07_attenuation",
"Spectrum Fit",
"Dust Features",
r"Atomic and $H_2$ Lines",
"Total Continuum Emissions",
"Continuum Components",
],
prop={"size": 10},
loc="best",
facecolor="white",
framealpha=1,
ncol=3,
)

# residuals, lower sub-figure
res = (y - model(x)) / x
std = np.std(res)
ax = axs[1]

ax.set_yscale("linear")
ax.set_xscale("log")
ax.tick_params(axis="both",
which="major",
top="on",
right="on",
direction="in",
length=10)
ax.tick_params(axis="both",
which="minor",
top="on",
right="on",
direction="in",
length=5)
ax.minorticks_on()

# Custom X axis ticks
ax.xaxis.set_ticks(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 25, 30, 40])

ax.axhline(0, linestyle="--", color="gray", zorder=0)
ax.plot(x, res, "ko-", fillstyle="none", zorder=1)
ax.set_ylim(-scalefac_resid * std, scalefac_resid * std)
ax.set_xlim(np.floor(np.amin(x)), np.ceil(np.amax(x)))
ax.set_xlabel(r"$\lambda$ [$\mu m$]")
ax.set_ylabel("Residuals [%]")

# scalar x-axis marks
ax.xaxis.set_minor_formatter(mpl.ticker.ScalarFormatter())
ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter())

@staticmethod
def update_dictionary(feature_dict, instrumentname, update_fwhms=False, redshift=0):
"""
Update parameter dictionary based on the instrument name.
Expand Down
Loading

0 comments on commit 68ce437

Please sign in to comment.