Skip to content

Commit

Permalink
move common prediction code to SpecModel method. Add basic polycal te…
Browse files Browse the repository at this point in the history
…st; fix imports; fix speccal bugs in SpecModel.
  • Loading branch information
bd-j committed Aug 8, 2024
1 parent eb4a2b0 commit 768c689
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 25 deletions.
9 changes: 4 additions & 5 deletions prospect/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
"""


from .sedmodel import ProspectorParams, SpecModel
from .sedmodel import PolySpecModel, SplineSpecModel
from .sedmodel import AGNSpecModel
from .parameters import ProspectorParams
from .sedmodel import SpecModel, HyperSpecModel, AGNSpecModel


__all__ = ["ProspectorParams",
"SpecModel",
"PolySpecModel", "SplineSpecModel",
"LineSpecModel", "AGNSpecModel"
"HyperSpecModel",
"AGNSpecModel"
]

48 changes: 34 additions & 14 deletions prospect/models/sedmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@


__all__ = ["SpecModel",
"PolySpecModel", "SplineSpecModel",
"HyperSpecModel", "HyperPolySpecModel",
"AGNSpecModel",
"PolyFitModel"]
"HyperSpecModel",
"AGNSpecModel"]


class SpecModel(ProspectorParams):
Expand Down Expand Up @@ -103,7 +101,28 @@ def predict(self, theta, observations=None, sps=None, **extras):
will be `mfrac` the ratio of the surviving stellar mass to the
stellar mass formed.
"""
self.predict_init(theta, sps)

# generate predictions for likelihood
# this assumes all spectral datasets (if present) occur first
# because they can change the line strengths during marginalization.
predictions = [self.predict_obs(obs) for obs in observations]

return predictions, self._mfrac

def predict_init(self, theta, sps):
"""Generate the physical model on the model wavelength grid, and cache
many quantities used in common for all kinds of predictions.
Parameters
----------
theta : ndarray of shape ``(ndim,)``
Vector of free model parameter values.
sps :
An `sps` object to be used in the model generation. It must have
the :py:func:`get_galaxy_spectrum` method defined.
"""
# generate and cache intrinsic model spectrum and info
self.set_parameters(theta)
self._wave, self._spec, self._mfrac = sps.get_galaxy_spectrum(**self.params)
Expand Down Expand Up @@ -131,13 +150,6 @@ def predict(self, theta, observations=None, sps=None, **extras):
self._smooth_spec = self.add_dla(self._wave, self._smooth_spec)
self._smooth_spec = self.add_damping_wing(self._wave, self._smooth_spec)

# generate predictions for likelihood
# this assumes all spectral datasets (if present) occur first
# because they can change the line strengths during marginalization.
predictions = [self.predict_obs(obs) for obs in observations]

return predictions, self._mfrac

def predict_obs(self, obs):
if obs.kind == "spectrum":
prediction = self.predict_spec(obs)
Expand Down Expand Up @@ -253,6 +265,7 @@ def predict_spec(self, obs):
obs_wave = self.observed_wave(self._wave, do_wavecal=False)

# get output wavelength vector
# TODO: remove this and require all Spectrum instances to have a wavelength array
self._outwave = obs.wavelength
if self._outwave is None:
self._outwave = obs_wave
Expand Down Expand Up @@ -285,15 +298,19 @@ def predict_spec(self, obs):
inst_spec[emask] += self._fix_eline_spec.sum(axis=1)

# --- (de-) apply calibration ---
extra_mask = self._fit_eline_pixelmask
if not extra_mask.any():
extra_mask = True # all pixels are ok
response = obs.compute_response(spec=inst_spec,
extra_mask=self._fit_eline_pixelmask,
extra_mask=extra_mask,
**self.params)
inst_spec = inst_spec * response

# --- fit and add lines if necessary ---
emask = self._fit_eline_pixelmask
if emask.any():
# We need the spectroscopic covariance matrix to do emission line optimization and marginalization
# We need the spectroscopic covariance matrix to do emission line
# optimization and marginalization
spec_unc = None
# FIXME: do this only if the noise model is non-trivial, and make sure masking is consistent
#vectors = obs.noise.populate_vectors(obs)
Expand All @@ -302,7 +319,8 @@ def predict_spec(self, obs):
inst_spec[emask] += self._fit_eline_spec.sum(axis=1)

# --- cache intrinsic spectrum for this observation ---
self._sed.append(inst_spec / response)
self._sed = inst_spec / response
self._speccal = response

return inst_spec

Expand Down Expand Up @@ -635,6 +653,7 @@ def fit_mle_elines(self, obs, calibrated_spec, sigma_spec=None):

# generate line amplitudes in observed flux units
units_factor = self.flux_norm() / (1 + self._zred)
# FIXME: use obs.response instead of _speccal, remove all references to speccal
calib_factor = np.interp(self._ewave_obs[idx], nebwave, self._speccal[emask])
linecal = units_factor * calib_factor
alpha_breve = self._eline_lum[idx] * linecal
Expand Down Expand Up @@ -959,6 +978,7 @@ def predict_spec(self, obs):

# --- cache intrinsic spectrum ---
self._sed = inst_spec / response
self._speccal = response

return inst_spec

Expand Down
5 changes: 4 additions & 1 deletion prospect/observation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# -*- coding: utf-8 -*-

from .observation import Observation
from .observation import Photometry, Spectrum, Lines, UndersampledSpectrum, IntrinsicSpectrum
from .observation import Photometry, Spectrum, Lines
from .observation import UndersampledSpectrum, IntrinsicSpectrum
from .observation import PolyOptCal, SplineOptCal
from .observation import from_oldstyle, from_serial

__all__ = ["Observation",
"Photometry", "Spectrum", "Lines",
"UndersampledSpectrum", "InstrinsicSpectrum",
"PolyOptCal", "SplineOptCal",
"from_oldstyle", "from_serial"]
10 changes: 5 additions & 5 deletions prospect/observation/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ def instrumental_response(self, **extras):
return 1.0



class Lines(Spectrum):

_kind = "lines"
Expand Down Expand Up @@ -532,10 +531,10 @@ def __init__(self, *args,
polynomial_regularization=0,
median_polynomial=0,
**kwargs):
super(PolyOptCal, self).__init(*args, **kwargs)
super(PolyOptCal, self).__init__(*args, **kwargs)
self.polynomial_order = polynomial_order
self.polynomial_regularization = polynomial_regularization
self.median_molynomial = median_polynomial
self.median_polynomial = median_polynomial

def _available_parameters(self):
# These should both be attached to the Observation instance as attributes
Expand Down Expand Up @@ -573,7 +572,8 @@ def compute_response(self, spec=None, extra_mask=True, **kwargs):
assert (self.mask.sum() > order), f"Not enough points to constrain polynomial of order {order}"

polyopt = (order > 0)
if ~polyopt:
if (not polyopt):
print("no polynomial")
self.response = np.ones_like(self.wavelength)
return self.response

Expand Down Expand Up @@ -614,7 +614,7 @@ def __init__(self, *args,
spline_knot_spacing=None,
spline_knot_n=None,
**kwargs):
super(SplineOptCal, self).__init(*args, **kwargs)
super(SplineOptCal, self).__init__(*args, **kwargs)

self.params = {}
if spline_knot_wave is not None:
Expand Down
80 changes: 80 additions & 0 deletions tests/test_polycal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import pytest

from prospect.sources import CSPSpecBasis
from prospect.models import SpecModel, templates
from prospect.observation import Spectrum, Photometry, PolyOptCal


class PolySpectrum(PolyOptCal, Spectrum):
pass



@pytest.fixture
def get_sps():
sps = CSPSpecBasis(zcontinuous=1)
return sps


def build_model(add_neb=False):
model_params = templates.TemplateLibrary["parametric_sfh"]
if add_neb:
model_params.update(templates.TemplateLibrary["nebular"])
return SpecModel(model_params)


def build_obs(multispec=False):
N = 1500 * (2 - multispec)
wmax = 7000
wsplit = wmax - N * multispec

fnames = list([f"sdss_{b}0" for b in "ugriz"])
Nf = len(fnames)
phot = [Photometry(filters=fnames,
flux=np.ones(Nf),
uncertainty=np.ones(Nf)/10)]
spec = [PolySpectrum(wavelength=np.linspace(4000, wsplit, N),
flux=np.ones(N),
uncertainty=np.ones(N) / 10,
mask=slice(None),
polynomial_order=5)
]

if multispec:
spec += [Spectrum(wavelength=np.linspace(wsplit+1, wmax, N),
flux=np.ones(N), uncertainty=np.ones(N) / 10,
mask=slice(None))]

obslist = spec + phot
[obs.rectify() for obs in obslist]
return obslist


def test_polycal(plot=False):
"""Make sure the polynomial optimization works
"""
sps = get_sps
observations = build_obs()
model = build_model()

preds, extra = model.predict(model.theta, observations=observations, sps=sps)
obs = observations[0]

assert np.any(obs.response != 0)

if plot:
import matplotlib.pyplot as pl
fig, axes = pl.subplots(3, 1, sharex=True)
ax = axes[0]
ax.plot(obs.wavelength, obs.flux, label="obseved flux (ones)")
ax.plot(obs.wavelength, preds[0], label="model flux (times response)")
ax = axes[1]
ax.plot(obs.wavelength, obs.response, label="instrumental response (polynomial)")
ax = axes[2]
ax.plot(obs.wavelength, preds[0]/ obs.response, label="intrinsic model spectrum")
ax.set_xlabel("wavelength")
[ax.legend() for ax in axes]

0 comments on commit 768c689

Please sign in to comment.