Skip to content

Commit

Permalink
Merge pull request #1275 from DavidT3/refine/revampHydroMassProfilePr…
Browse files Browse the repository at this point in the history
…oduct

Refine/revamp hydro mass profile product
  • Loading branch information
DavidT3 authored Nov 22, 2024
2 parents 511eabf + 4a1a25b commit 63502dd
Show file tree
Hide file tree
Showing 5 changed files with 807 additions and 257 deletions.
10 changes: 6 additions & 4 deletions xga/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
# Last modified by David J Turner (turne540@msu.edu) 29/07/2024, 21:58. Copyright (c) The Contributors
# Last modified by David J Turner (turne540@msu.edu) 20/11/2024, 12:00. Copyright (c) The Contributors

import inspect
from types import FunctionType
Expand All @@ -8,19 +8,21 @@
# it becomes a big inefficiency
from .density import *
from .entropy import *
from .mass import *
from .misc import *
from .sb import *
from .temperature import *

# This dictionary is meant to provide pretty versions of model/function names to go in plots
# This method of merging dictionaries only works in Python 3.5+, but that should be fine
MODEL_PUBLICATION_NAMES = {**DENS_MODELS_PUB_NAMES, **MISC_MODELS_PUB_NAMES, **SB_MODELS_PUB_NAMES,
**TEMP_MODELS_PUB_NAMES, **ENTROPY_MODELS_PUB_NAMES}
**TEMP_MODELS_PUB_NAMES, **ENTROPY_MODELS_PUB_NAMES, **MASS_MODELS_PUB_NAMES}
MODEL_PUBLICATION_PAR_NAMES = {**DENS_MODELS_PAR_NAMES, **MISC_MODELS_PAR_NAMES, **SB_MODELS_PAR_NAMES,
**TEMP_MODELS_PAR_NAMES, **ENTROPY_MODELS_PAR_NAMES}
**TEMP_MODELS_PAR_NAMES, **ENTROPY_MODELS_PAR_NAMES, **MASS_MODELS_PAR_NAMES}
# These dictionaries tell the profile fitting function what models, start pars, and priors are allowed
PROF_TYPE_MODELS = {"brightness": SB_MODELS, "gas_density": DENS_MODELS, "gas_temperature": TEMP_MODELS,
'1d_proj_temperature': TEMP_MODELS, 'specific_entropy': ENTROPY_MODELS}
'1d_proj_temperature': TEMP_MODELS, 'specific_entropy': ENTROPY_MODELS,
'hydrostatic_mass': MASS_MODELS}


def convert_to_odr_compatible(model_func: FunctionType, new_par_name: str = 'β', new_data_name: str = 'x_values') \
Expand Down
66 changes: 54 additions & 12 deletions xga/models/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
# Last modified by David J Turner (turne540@msu.edu) 07/08/2024, 10:14. Copyright (c) The Contributors
# Last modified by David J Turner (turne540@msu.edu) 21/11/2024, 11:58. Copyright (c) The Contributors

import inspect
from abc import ABCMeta, abstractmethod
Expand Down Expand Up @@ -222,7 +222,8 @@ def get_realisations(self, x: Quantity) -> Quantity:
x = x.to(self._x_unit)

if self._x_lims is not None and (np.any(x < self._x_lims[0]) or np.any(x > self._x_lims[1])):
warn("Some x values are outside of the x-axis limits for this model, results may not be trustworthy.")
warn("Some x values are outside of the x-axis limits for this model, results may not be trustworthy.",
stacklevel=2)

if x.isscalar or (not x.isscalar and x.ndim == 1):
realisations = self.model(x[..., None], *self._par_dists)
Expand All @@ -234,7 +235,7 @@ def get_realisations(self, x: Quantity) -> Quantity:
# statement
realisations = self.model(x, *self._par_dists)

return realisations
return realisations.to(self._y_unit)

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -564,7 +565,7 @@ def compare_units(check_pars: List[Quantity], good_pars: List[Quantity]) -> List
:param List[Quantity] good_pars: The second list of parameters, these are taken as having 'correct'
units.
:return: Only if the check pars pass the tests. We return the check pars list but with all elements
converted to EXACTLY the same units as good_pars, not just equivelant.
converted to EXACTLY the same units as good_pars, not just equivalent.
:rtype: List[Quantity]
"""
if len(check_pars) != len(good_pars):
Expand Down Expand Up @@ -659,13 +660,13 @@ def predicted_dist_view(self, radius: Quantity, bins: Union[str, int] = 'auto',
else:
warn("You have not added parameter distributions to this model")

def par_dist_view(self, bins: Union[str, int] = 'auto', colour: str = "lightslategrey"):
def par_dist_view(self, bins: Union[str, int] = 'auto', colour: str = "lightseagreen"):
"""
Very simple method that allows you to view the parameter distributions that have been added to this
model. The model parameter and uncertainties are indicated with red lines, highlighting the value
and enclosing the 1sigma confidence region.
:param Union[str, int] bins: Equivelant to the plt.hist bins argument, set either the number of bins
:param Union[str, int] bins: Equivalent to the plt.hist bins argument, set either the number of bins
or the algorithm to decide on the number of bins.
:param str colour: Set the colour of the histogram.
"""
Expand All @@ -679,33 +680,74 @@ def par_dist_view(self, bins: Union[str, int] = 'auto', colour: str = "lightslat
for ax_ind, ax in enumerate(ax_arr):
# Add histogram
ax.hist(self.par_dists[ax_ind].value, bins=bins, color=colour)
# Add parameter value as a solid red line
ax.axvline(self.model_pars[ax_ind].value, color='red')

# Read out the errors
err = self.model_par_errs[ax_ind]
# Depending how many entries there are per parameter in the error quantity depends how we plot them

# Define the unit of this parameter
cur_unit = err.unit

# Change how we plot depending on how many entries there are per parameter in the error quantity
if err.isscalar:
# Change how we format the value label depending on how big the number is effectively
if (self.model_pars[ax_ind].value / 1000) > 1:
v_ord = len(str(self.model_pars[ax_ind].value).split('.')[0]) - 1
cur_v_str = str((self.model_pars[ax_ind].value / (10**v_ord)).round(2))
cur_e_str = str((err.value / (10**v_ord)).round(2)) + r"\times 10^{" + str(v_ord) + "}"
else:
cur_v_str = str(self.model_pars[ax_ind].round(2).value)
cur_e_str = str(err.round(2).value)

# Set up the label that will accompany the vertical lines to indicate parameter value and error
vals_label = cur_v_str + r"\pm" + cur_e_str

ax.axvline(self.model_pars[ax_ind].value - err.value, color='red', linestyle='dashed')
ax.axvline(self.model_pars[ax_ind].value + err.value, color='red', linestyle='dashed')
elif not err.isscalar and len(err) == 2:

# Change how we format the value label depending on how big the number is effectively
if (self.model_pars[ax_ind].value / 1000) > 1:
v_ord = len(str(self.model_pars[ax_ind].value).split('.')[0]) - 1

cur_v_str = str((self.model_pars[ax_ind].value / (10 ** v_ord)).round(2))
cur_em_str = str((err[0].value / (10 ** v_ord)).round(2))
cur_ep_str = str((err[1].value / (10 ** v_ord)).round(2))
# Set up the label that will accompany the vertical lines to indicate parameter value and error
vals_label = (cur_v_str + "^{+" + cur_ep_str + "}_{-" + cur_em_str + "} " +
r"\times 10^{" + str(v_ord) + "}")
else:
cur_v_str = str(self.model_pars[ax_ind].round(2).value)
cur_em_str = str(err[0].round(2).value)
cur_ep_str = str(err[1].round(2).value)
# Set up the label that will accompany the vertical lines to indicate parameter value and error
vals_label = (cur_v_str + "^{+" + cur_ep_str + "}_{-" + cur_em_str + "}")

ax.axvline(self.model_pars[ax_ind].value - err[0].value, color='red', linestyle='dashed')
ax.axvline(self.model_pars[ax_ind].value + err[1].value, color='red', linestyle='dashed')
else:
raise ValueError("Parameter error has three elements in it!")

cur_unit = err.unit
# The full label for the vertical line that indicates the parameter value
res_label = (r"$" + self.par_publication_names[ax_ind].replace('$', '') + "= "
+ vals_label + cur_unit.to_string("latex").strip("$") + '$')

# Add parameter value as a solid red line
ax.axvline(self.model_pars[ax_ind].value, color='red', label=res_label)

if cur_unit == Unit(''):
par_unit_name = ""
else:
par_unit_name = r" $\left[" + cur_unit.to_string("latex").strip("$") + r"\right]$"

ax.set_xlabel(self.par_publication_names[ax_ind] + par_unit_name)
ax.set_xlabel(self.par_publication_names[ax_ind] + par_unit_name, fontsize=14)
ax.legend(loc='best')


# And show the plot
plt.tight_layout()
plt.show()
else:
warn("You have not added parameter distributions to this model")
warn("You have not added parameter distributions to this model", stacklevel=2)

def view(self, radii: Quantity = None, xscale: str = 'log', yscale: str = 'log', figsize: tuple = (8, 8),
colour: str = "black"):
Expand Down
110 changes: 110 additions & 0 deletions xga/models/mass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# This code is a part of X-ray: Generate and Analyse (XGA), a module designed for the XMM Cluster Survey (XCS).
# Last modified by David J Turner (turne540@msu.edu) 20/11/2024, 22:32. Copyright (c) The Contributors

from typing import Union, List

import numpy as np
from astropy.units import Quantity, Unit, UnitConversionError, kpc, deg

from .base import BaseModel1D
from ..utils import r500, r200, r2500


class NFW(BaseModel1D):
"""
A simple model to fit galaxy cluster mass profiles (https://ui.adsabs.harvard.edu/abs/1997ApJ...490..493N/abstract)
- a cumulative mass version of the Navarro-Frenk-White profile. Typically, the NFW is formulated in terms of mass
density, but one can derive a mass profile from it (https://ui.adsabs.harvard.edu/abs/2006MNRAS.368..518V/abstract).
The NFW is extremely widely used, though generally for dark matter mass profiles, but will act as a handy
functional form to fit to data-driven mass profiles derived from X-ray observations of clusters.
:param Unit/str x_unit: The unit of the x-axis of this model, kpc for instance. May be passed as a string
representation or an astropy unit object.
:param Unit/str y_unit: The unit of the output of this model, Msun for instance. May be passed as a string
representation or an astropy unit object.
:param List[Quantity] cust_start_pars: The start values of the model parameters for any fitting function that
used start values. The units are checked against default start values.
"""
def __init__(self, x_unit: Union[str, Unit] = 'kpc', y_unit: Union[str, Unit] = Unit('Msun'),
cust_start_pars: List[Quantity] = None):
"""
The init of a subclass of the XGA BaseModel1D class, describing the shape of cumulative mass profiles for
a galaxy cluster based on the NFW mass density profile.
"""

# If a string representation of a unit was passed then we make it an astropy unit
if isinstance(x_unit, str):
x_unit = Unit(x_unit)
if isinstance(y_unit, str):
y_unit = Unit(y_unit)

poss_y_units = [Unit('Msun')]
y_convertible = [u.is_equivalent(y_unit) for u in poss_y_units]
if not any(y_convertible):
allowed = ", ".join([u.to_string() for u in poss_y_units])
raise UnitConversionError("{p} is not convertible to any of the allowed units; "
"{a}".format(p=y_unit.to_string(), a=allowed))
else:
yu_ind = y_convertible.index(True)

poss_x_units = [kpc, deg, r200, r500, r2500]
x_convertible = [u.is_equivalent(x_unit) for u in poss_x_units]
if not any(x_convertible):
allowed = ", ".join([u.to_string() for u in poss_x_units])
raise UnitConversionError("{p} is not convertible to any of the allowed units; "
"{a}".format(p=x_unit.to_string(), a=allowed))
else:
xu_ind = x_convertible.index(True)

r_scale_starts = [Quantity(100, 'kpc'), Quantity(0.2, 'deg'), Quantity(0.05, r200), Quantity(0.1, r500),
Quantity(0.5, r2500)]
# We will implement the NFW mass profile with a rho_0 normalization parameter, a density - and leave in the
# volume integration terms - rather than fitting for some mass normalization
norm_starts = [Quantity(1e+13, 'Msun/Mpc^3')]

start_pars = [r_scale_starts[xu_ind], norm_starts[yu_ind]]
if cust_start_pars is not None:
# If the custom start parameters can run this gauntlet without tripping an error then we're all good
# This method also returns the custom start pars converted to exactly the same units as the default
start_pars = self.compare_units(cust_start_pars, start_pars)

r_core_priors = [{'prior': Quantity([0, 2000], 'kpc'), 'type': 'uniform'},
{'prior': Quantity([0, 1], 'deg'), 'type': 'uniform'},
{'prior': Quantity([0, 1], r200), 'type': 'uniform'},
{'prior': Quantity([0, 1], r500), 'type': 'uniform'},
{'prior': Quantity([0, 1], r2500), 'type': 'uniform'}]
norm_priors = [{'prior': Quantity([1e+12, 1e+16], 'Msun/Mpc^3'), 'type': 'uniform'}]

priors = [r_core_priors[xu_ind], norm_priors[yu_ind]]

nice_pars = [r"R$_{\rm{s}}$", r"$\rho_{0}$"]
info_dict = {'author': 'Navarro J, Frenk C, White S', 'year': '1997',
'reference': 'https://ui.adsabs.harvard.edu/abs/1997ApJ...490..493N/abstract',
'general': 'The cumulative mass version of the NFW mass-density profile for galaxy \n'
'clusters - normally used to describe dark matter profiles.'}

super().__init__(x_unit, y_unit, start_pars, priors, 'nfw', 'NFW Profile', nice_pars, 'Mass',
info_dict)

@staticmethod
def model(x: Quantity, r_scale: Quantity, rho_zero: Quantity) -> Quantity:
"""
The model function for the constant-core and power-law entropy model.
:param Quantity x: The radii to calculate y values for.
:param Quantity r_scale: The scale radius parameter.
:param Quantity rho_zero: A density normalization parameter.
:return: The y values corresponding to the input x values.
:rtype: Quantity
"""

norm_rad = x / r_scale
result = 4*np.pi*rho_zero*np.power(r_scale, 3)*(np.log(1 + norm_rad) - (norm_rad / (1 + norm_rad)))
return result


# So that things like fitting functions can be written generally to support different models
MASS_MODELS = {"nfw": NFW}
MASS_MODELS_PUB_NAMES = {n: m().publication_name for n, m in MASS_MODELS.items()}
MASS_MODELS_PAR_NAMES = {n: m().par_publication_names for n, m in MASS_MODELS.items()}
Loading

0 comments on commit 63502dd

Please sign in to comment.