diff --git a/beast/observationmodel/observations.py b/beast/observationmodel/observations.py index 30e09a433..d2b12ae4a 100644 --- a/beast/observationmodel/observations.py +++ b/beast/observationmodel/observations.py @@ -8,7 +8,7 @@ from beast.observationmodel.vega import Vega from beast.physicsmodel.priormodel import PriorAgeModel, PriorMassModel -from beast.physicsmodel.grid_weights_stars import compute_bin_boundaries +from beast.physicsmodel.grid_weights import compute_bin_boundaries __all__ = ["Observations", "gen_SimObs_from_sedgrid"] diff --git a/beast/physicsmodel/creategrid.py b/beast/physicsmodel/creategrid.py index 3c28a5704..44d47c7a0 100644 --- a/beast/physicsmodel/creategrid.py +++ b/beast/physicsmodel/creategrid.py @@ -23,7 +23,9 @@ from beast.physicsmodel.stars import stellib from beast.physicsmodel.grid import SpectralGrid, SEDGrid -from beast.physicsmodel.grid_and_prior_weights import compute_av_rv_fA_prior_weights +from beast.physicsmodel.priormodel import PriorDustModel + +from beast.physicsmodel.grid_weights import compute_grid_weights from astropy.table import Table from beast.tools.helpers import generator @@ -321,7 +323,7 @@ def make_extinguished_grid( # Create the sampling mesh # ======================== # basically the dot product from all input 1d vectors - # setup integration over the full dust parameter grid + # setup interation over the full dust parameter grid if with_fA: @@ -425,15 +427,25 @@ def make_extinguished_grid( cols["Rv"][N0 * count : N0 * (count + 1)] = Rv # compute the dust weights - dust_prior_weight = compute_av_rv_fA_prior_weights( - Av, - Rv, - f_A, - g0.grid["distance"].data, - av_prior_model=av_prior_model, - rv_prior_model=rv_prior_model, - fA_prior_model=fA_prior_model, - ) + # moved here in 2023 to support distance based dust priors + dists = g0.grid["distance"].data + if av_prior_model["name"] == "step": + av_prior_weights = av_prior(np.full((len(dists)), Av), y=dists) + else: + av_prior_weights = av_prior(Av) + if rv_prior_model["name"] == "step": + rv_prior_weights = rv_prior(np.full((len(dists)), Rv), y=dists) + else: + rv_prior_weights = rv_prior(Rv) + if fA_prior_model["name"] == "step": + f_A_prior_weights = fA_prior(np.full((len(dists)), f_A), y=dists) + else: + if with_fA: + f_A_prior_weights = fA_prior(f_A) + else: + f_A_prior_weights = 1.0 + + dust_prior_weight = av_prior_weights * rv_prior_weights * f_A_prior_weights # get new attributes if exist for key in list(temp_results.grid.keys()): @@ -468,6 +480,25 @@ def make_extinguished_grid( _lamb = cols.pop("lamb") + # now add the grid weights + av_grid_weights = compute_grid_weights(avs) + for cav, cav_gweight in zip(avs, av_grid_weights): + gvals = cols["Av"] == cav + cols["weight"][gvals] *= cav_gweight + cols["grid_weight"][gvals] *= cav_gweight + + rv_grid_weights = compute_grid_weights(rvs) + for rav, rav_gweight in zip(rvs, rv_grid_weights): + gvals = cols["Rv"] == rav + cols["weight"][gvals] *= rav_gweight + cols["grid_weight"][gvals] *= rav_gweight + + fA_grid_weights = compute_grid_weights(fAs) + for cfA, cfA_gweight in zip(fAs, fA_grid_weights): + gvals = cols["f_A"] == cfA + cols["weight"][gvals] *= cfA_gweight + cols["grid_weight"][gvals] *= cfA_gweight + # free the memory of temp_results # del temp_results # del tempgrid diff --git a/beast/physicsmodel/grid_and_prior_weights.py b/beast/physicsmodel/grid_and_prior_weights.py index dd056daaf..e5ce8cd11 100644 --- a/beast/physicsmodel/grid_and_prior_weights.py +++ b/beast/physicsmodel/grid_and_prior_weights.py @@ -15,10 +15,7 @@ import numpy as np -from beast.physicsmodel.grid_weights_stars import compute_distance_grid_weights -from beast.physicsmodel.grid_weights_stars import compute_age_grid_weights -from beast.physicsmodel.grid_weights_stars import compute_mass_grid_weights -from beast.physicsmodel.grid_weights_stars import compute_metallicity_grid_weights +from beast.physicsmodel.grid_weights import compute_grid_weights from beast.physicsmodel.priormodel import ( PriorAgeModel, @@ -87,7 +84,7 @@ def compute_distance_age_mass_metallicity_weights( if n_dist > 1: # get the distance weights - dist_grid_weights = compute_distance_grid_weights(uniq_dists) + dist_grid_weights = compute_grid_weights(uniq_dists) dist_grid_weights /= np.sum(dist_grid_weights) dist_prior = PriorDistanceModel(distance_prior_model) dist_prior_weights = dist_prior(uniq_dists) @@ -157,7 +154,7 @@ def compute_age_mass_metallicity_weights( uniq_ages = np.unique(_tgrid[zindxs]["logA"]) # compute the age weights - age_grid_weights = compute_age_grid_weights(uniq_ages) + age_grid_weights = compute_grid_weights(uniq_ages, log=True) if isinstance(age_prior_model, dict): age_prior = PriorAgeModel(age_prior_model) else: @@ -187,7 +184,7 @@ def compute_age_mass_metallicity_weights( cur_masses = np.unique(_tgrid_single_age["M_ini"]) n_masses = len(_tgrid_single_age["M_ini"]) if len(cur_masses) < n_masses: - umass_grid_weights = compute_mass_grid_weights(cur_masses) + umass_grid_weights = compute_grid_weights(cur_masses) umass_prior_weights = mass_prior(cur_masses) mass_grid_weights = np.zeros(n_masses, dtype=float) mass_prior_weights = np.zeros(n_masses, dtype=float) @@ -197,7 +194,7 @@ def compute_age_mass_metallicity_weights( mass_prior_weights[gvals] = umass_prior_weights[k] else: cur_masses = _tgrid_single_age["M_ini"] - mass_grid_weights = compute_mass_grid_weights(cur_masses) + mass_grid_weights = compute_grid_weights(cur_masses) mass_prior_weights = mass_prior(cur_masses) else: @@ -222,7 +219,7 @@ def compute_age_mass_metallicity_weights( # ensure that the metallicity prior is uniform if len(uniq_Zs) > 1: # get the metallicity weights - met_grid_weights = compute_metallicity_grid_weights(uniq_Zs) + met_grid_weights = compute_grid_weights(uniq_Zs) met_grid_weights /= np.sum(met_grid_weights) met_prior = PriorMetallicityModel(met_prior_model) met_prior_weights = met_prior(uniq_Zs) diff --git a/beast/physicsmodel/grid_weights.py b/beast/physicsmodel/grid_weights.py new file mode 100644 index 000000000..074a14d90 --- /dev/null +++ b/beast/physicsmodel/grid_weights.py @@ -0,0 +1,77 @@ +import numpy as np + +__all__ = ["compute_grid_weights", "compute_bin_boundaries"] + + +def compute_grid_weights(in_x, log=False): + """ + Compute the grid weights. Needed for marginalization (aka integration). The + weights are the relative widths of of each x bin. + + Parameters + ---------- + x : numpy array + centers of each bin + + log : boolean + set if values are in log units + + Returns + ------- + weights : numpy array + weights as bin widths divided by the average width + """ + # ensure x values are monotonically increasing + sindxs = np.argsort(in_x) + x = in_x[sindxs] + + n_x = len(x) + bin_hdiffs = np.diff(x) / 2.0 + + # define the bin min and max boundaries + # handling the two edge cases + bin_mins = np.zeros(n_x) + bin_mins[1:] = x[1:] - bin_hdiffs + bin_mins[0] = x[0] - bin_hdiffs[0] + + bin_maxs = np.zeros(n_x) + bin_maxs[0:-1] = x[0:-1] + bin_hdiffs + bin_maxs[-1] = x[-1] + bin_hdiffs[-1] + + if log: + weights = (10**bin_maxs) - (10**bin_mins) + else: + weights = bin_maxs - bin_mins + + # put the weights in the same order as in_x + out_weights = np.zeros(n_x) + out_weights[sindxs] = weights + + # return normalized weights to avoid numerical issues + return out_weights / np.average(out_weights) + + +def compute_bin_boundaries(tab): + """ + Computes the boundaries of bins + + The bin boundaries are defined as the midpoint between each value in tab. + At the two edges, 1/2 of the bin width is subtracted/added to the + min/max of tab. + + Parameters + ---------- + tab : numpy array + centers of each bin + + Returns + ------- + tab2 : numpy array + boundaries of the bins + """ + temp = tab[1:] - np.diff(tab) / 2.0 + tab2 = np.zeros(len(tab) + 1) + tab2[0] = tab[0] - np.diff(tab)[0] / 2.0 + tab2[-1] = tab[-1] + np.diff(tab)[-1] / 2.0 + tab2[1:-1] = temp + return tab2 \ No newline at end of file diff --git a/beast/physicsmodel/grid_weights_stars.py b/beast/physicsmodel/grid_weights_stars.py deleted file mode 100644 index 3193ca3e1..000000000 --- a/beast/physicsmodel/grid_weights_stars.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Grid Weights -============ -The use of a non-uniformly spaced grid complicates the marginalization -step as the trick of summation instead of integration is used. But this -trick only works when the grid is uniformly spaced in all dimensions. - -If the grid is not uniformly spaced, weights can be used to correct -for the non-uniform spacing. -""" -import numpy as np - -__all__ = [ - "compute_distance_grid_weights", - "compute_age_grid_weights", - "compute_mass_grid_weights", - "compute_metallicity_grid_weights", - "compute_bin_boundaries", -] - - -def compute_bin_boundaries(tab): - """ - Computes the boundaries of bins - - The bin boundaries are defined as the midpoint between each value in tab. - At the two edges, 1/2 of the bin width is subtractted/added to the - min/max of tab. - - Parameters - ---------- - tab : numpy array - centers of each bin - - Returns - ------- - tab2 : numpy array - boundaries of the bins - """ - temp = tab[1:] - np.diff(tab) / 2.0 - tab2 = np.zeros(len(tab) + 1) - tab2[0] = tab[0] - np.diff(tab)[0] / 2.0 - tab2[-1] = tab[-1] + np.diff(tab)[-1] / 2.0 - tab2[1:-1] = temp - return tab2 - - -def compute_age_grid_weights(logages): - """ - Computes the age weights to set a uniform prior on linear SFR - - Parameters - ---------- - logages : numpy vector - log(ages) - - Returns - ------- - age_weights : numpy vector - total masses at each age for a constant SFR in linear age - """ - # ages need to be monotonically increasing - aindxs = np.argsort(logages) - - # Computes the bin boundaries in log - logages_bounds = compute_bin_boundaries(logages[aindxs]) - - # initialize the age weights - age_weights = np.full(len(aindxs), 0.0) - - # Returns the age weight as a numpy array - age_weights[aindxs] = np.diff(10 ** (logages_bounds)) - - # normalize to avoid numerical issues (too small or too large) - age_weights /= np.average(age_weights) - - # return in the order that logages was passed - return age_weights - - -def compute_mass_grid_weights(masses): - """ - Computes the mass weights to set a uniform prior on linear mass - - Parameters - ---------- - masses : numpy vector - masses - - Returns - ------- - mass_weights : numpy vector - weights to provide a constant SFR in linear age - """ - # sort the initial mass along this isochrone - sindxs = np.argsort(masses) - - # Compute the mass bin boundaries - masses_bounds = compute_bin_boundaries(masses[sindxs]) - - # compute the weights = bin widths - mass_weights = np.zeros(len(masses)) - mass_weights[sindxs] = np.diff(masses_bounds) - - # normalize to avoid numerical issues (too small or too large) - mass_weights /= np.average(mass_weights) - - return mass_weights - - -def compute_metallicity_grid_weights(mets): - """ - Computes the metallicity weights to set a uniform prior on linear metallicity - - Parameters - ---------- - mets : numpy vector - metallicities - - Returns - ------- - metallicity_weights : numpy vector - weights to provide a flat metallicity - """ - # sort the initial mass along this isochrone - sindxs = np.argsort(mets) - - # Compute the mass bin boundaries - mets_bounds = compute_bin_boundaries(mets[sindxs]) - - # compute the weights = bin widths - mets_weights = np.zeros(len(mets)) - mets_weights[sindxs] = np.diff(mets_bounds) - - # normalize to avoid numerical issues (too small or too large) - mets_weights /= np.average(mets_weights) - - return mets_weights - - -def compute_distance_grid_weights(dists): - """ - Computes the distance weights to set a uniform prior on linear distance - - Parameters - ---------- - dists : numpy vector - distances - - Returns - ------- - dist_weights : numpy vector - weights to provide a flat distance - """ - # sort - tdists = np.array(dists) - sindxs = np.argsort(tdists) - - # Compute the bin boundaries - dists_bounds = compute_bin_boundaries(tdists[sindxs]) - - # compute the weights = bin widths - dists_weights = np.zeros(len(tdists)) - dists_weights[sindxs] = np.diff(dists_bounds) - - # normalize to avoid numerical issues (too small or too large) - dists_weights /= np.average(dists_weights) - - return dists_weights diff --git a/beast/physicsmodel/model_grid.py b/beast/physicsmodel/model_grid.py index 626c0f8ae..43461ab8e 100644 --- a/beast/physicsmodel/model_grid.py +++ b/beast/physicsmodel/model_grid.py @@ -435,7 +435,17 @@ def make_extinguished_sed_grid( """ # create the dust grid arrays - avs = np.arange(av[0], av[1] + 0.5 * av[2], av[2]) + if len(av) > 3: + # check if a log grid is requested + if av[3] == "log": + print("generating a log av grid") + avs = 10 ** np.arange(np.log10(av[0]), np.log10(av[1]), av[2]) + else: + print("generating a linear av grid") + avs = np.arange(av[0], av[1] + 0.5 * av[2], av[2]) + else: + print("generating a linear av grid") + avs = np.arange(av[0], av[1] + 0.5 * av[2], av[2]) rvs = np.arange(rv[0], rv[1] + 0.5 * rv[2], rv[2]) if fA is not None: fAs = np.arange(fA[0], fA[1] + 0.5 * fA[2], fA[2]) diff --git a/beast/physicsmodel/priormodel.py b/beast/physicsmodel/priormodel.py index 5c0fe5752..17574d830 100644 --- a/beast/physicsmodel/priormodel.py +++ b/beast/physicsmodel/priormodel.py @@ -3,7 +3,7 @@ from scipy.integrate import quad import astropy.units as u -from beast.physicsmodel.grid_weights_stars import compute_bin_boundaries +from beast.physicsmodel.grid_weights import compute_bin_boundaries import beast.physicsmodel.priormodel_functions as pmfuncs diff --git a/beast/physicsmodel/tests/test_stellar_grid_weights.py b/beast/physicsmodel/tests/test_stellar_grid_weights.py index a304195e2..ae7be50a5 100644 --- a/beast/physicsmodel/tests/test_stellar_grid_weights.py +++ b/beast/physicsmodel/tests/test_stellar_grid_weights.py @@ -1,10 +1,7 @@ import numpy as np -from beast.physicsmodel.grid_weights_stars import ( - compute_distance_grid_weights, - compute_age_grid_weights, - compute_mass_grid_weights, - compute_metallicity_grid_weights, +from beast.physicsmodel.grid_weights import ( + compute_grid_weights, compute_bin_boundaries, ) @@ -26,7 +23,7 @@ def test_age_grid_weights(): Test age grid weights """ ages = np.array([6, 7, 8, 9, 10]) - weights = compute_age_grid_weights(ages) + weights = compute_grid_weights(ages, log=True) expected_weights = [ 4.500045e-04, 4.500045e-03, @@ -44,7 +41,7 @@ def test_mass_grid_weights(): Test mass grid weights """ masses = np.array([1, 2, 5, 7, 10, 30]) - weights = compute_mass_grid_weights(masses) + weights = compute_grid_weights(masses) expected_weights = [ 0.15189873, 0.30379747, @@ -63,7 +60,7 @@ def test_metallicity_grid_weights(): Test metallicities grid weights """ metallicities = np.array([0.03, 0.019, 0.008, 0.004]) - weights = compute_metallicity_grid_weights(metallicities) + weights = compute_grid_weights(metallicities) expected_weights = [1.31343284, 1.31343284, 0.89552239, 0.47761194] np.testing.assert_allclose( weights, expected_weights, err_msg=("Stellar grid metallicity weights error") @@ -77,7 +74,7 @@ def test_flat_distance_grid_weight(): dists = [10.0, 100.0, 1000.0] expected_weights = [0.18181818, 1.0, 1.81818182] - weight = compute_distance_grid_weights(dists) + weight = compute_grid_weights(dists) np.testing.assert_allclose( weight, expected_weights, err_msg=("Stellar grid flat distance weights error") diff --git a/beast/tools/verify_beast_settings.py b/beast/tools/verify_beast_settings.py index 3abd112c9..bd4d4e0a3 100644 --- a/beast/tools/verify_beast_settings.py +++ b/beast/tools/verify_beast_settings.py @@ -22,7 +22,7 @@ def verify_range(param, param_name, param_lim): def check_grid(param, param_name, param_lim): # check if input param limits and grid initialisation make sense - param_min, param_max, param_step = param + param_min, param_max, param_step = param[0:3] if param_min < param_lim[0]: raise ValueError(param_name + " min value not physical.") @@ -68,7 +68,13 @@ def verify_one_input_format(param, param_name, param_format, param_lim): else: raise TypeError(param_name + " is not in the right format - a list.") elif "float" in param_format: - is_list_of_floats = all(isinstance(item, float) for item in param) + if len(param) > 3: + tparam = param[0:3] + if param[3] not in ["log", "lin"]: + raise ValueError(f"4th element in {param_name} is not log or lin") + else: + tparam = param + is_list_of_floats = all(isinstance(item, float) for item in tparam) if not is_list_of_floats: raise TypeError( param_name + " is not in the right format - list of floats."