Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify feature bounds using np.nan instead of masking #283

Merged
merged 13 commits into from
May 13, 2024
28 changes: 15 additions & 13 deletions pahfit/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import astropy.units as u
from pkg_resources import resource_filename
from pahfit.errors import PAHFITFeatureError
from pahfit.features.features_format import BoundedMaskedColumn, BoundedParTableFormatter
from pahfit.features.features_format import BoundedParTableFormatter


class UniqueKeyLoader(yaml.SafeLoader):
Expand Down Expand Up @@ -69,8 +69,8 @@ def value_bounds(val, bounds):
Returns:
-------

The value, if unbounded, or a 3 element tuple (value, min, max).
Any missing bound is replaced with the numpy `masked' value.
A 3 element tuple (value, min, max).
Any missing bound is replaced with the numpy.nan value.

Raises:
-------
Expand All @@ -82,7 +82,7 @@ def value_bounds(val, bounds):
if val is None:
val = np.ma.masked
if not bounds:
return (val,) + 2 * (np.ma.masked,) # Fixed
return (val,) + 2 * (np.nan,) # (val,nan,nan) indicates fixed
jdtsmith marked this conversation as resolved.
Show resolved Hide resolved
ret = [val]
for i, b in enumerate(bounds):
if isinstance(b, str):
Expand All @@ -109,7 +109,6 @@ class Features(Table):
"""

TableFormatter = BoundedParTableFormatter
MaskedColumn = BoundedMaskedColumn

param_covar = TableAttribute(default=[])
_kind_params = {'starlight': {'temperature', 'tau'},
Expand All @@ -122,7 +121,8 @@ class Features(Table):
_units = {'temperature': u.K, 'wavelength': u.um, 'fwhm': u.um}
_group_attrs = set(('bounds', 'features', 'kind')) # group-level attributes
_param_attrs = set(('value', 'bounds')) # Each parameter can have these attributes
_no_bounds = set(('name', 'group', 'geometry', 'model')) # String attributes (no bounds)
_no_bounds = set(('name', 'group', 'kind', 'geometry', 'model')) # str attributes (no bounds)
_bounds_dtype = np.dtype([("val", "f4"), ("min", "f4"), ("max", "f4")])

@classmethod
def read(cls, file, *args, **kwargs):
Expand Down Expand Up @@ -308,19 +308,17 @@ def _construct_table(cls, inp: dict):
tables = []
for (kind, features) in inp.items():
kind_params = cls._kind_params[kind] # All params for this kind
rows = []
rows, dtypes = [], []
for (name, params) in features.items():
for missing in kind_params - params.keys():
if missing in cls._no_bounds:
params[missing] = 0.0
else:
params[missing] = value_bounds(0.0, bounds=(0.0, None))
rows.append(dict(name=name, **params))
table_columns = rows[0].keys()
t = cls(rows, names=table_columns)
for p in cls._kind_params[kind]:
if p not in cls._no_bounds:
t[p].info.format = "0.4g" # Nice format (customized by Formatter)
param_names = rows[0].keys()
dtypes = [str if x in cls._no_bounds else cls._bounds_dtype for x in param_names]
t = cls(rows, names=param_names, dtype=dtypes)
tables.append(t)
tables = vstack(tables)
for cn, col in tables.columns.items():
Expand Down Expand Up @@ -352,8 +350,12 @@ def mask_feature(self, name, mask_value=True):
pass
else:
# mask only the value, not the bounds
row[col_name].mask[0] = mask_value
row[col_name].mask['val'] = mask_value

def unmask_feature(self, name):
"""Remove the mask for all parameters of a feature."""
self.mask_feature(name, mask_value=False)

def _base_repr_(self, *args, **kwargs):
"""Omit dtype on self-print."""
return super()._base_repr_(*args, ** kwargs | dict(show_dtype=False))
60 changes: 26 additions & 34 deletions pahfit/features/features_format.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,45 @@
import numpy.ma as ma
from astropy.table import MaskedColumn
import numpy as np
from astropy.table.pprint import TableFormatter


# * Special table formatting for bounded (val, min, max) values
def fmt_func(fmt):
def _fmt(v):
if ma.is_masked(v[0]):
return " <n/a> "
if ma.is_masked(v[1]):
return f"{v[0]:{fmt}} (Fixed)"
return f"{v[0]:{fmt}} ({v[1]:{fmt}}, {v[2]:{fmt}})"

def fmt_func(fmt: str):
"""Format bounded variables specially."""
if fmt.startswith('%'):
fmt = fmt[1:]

def _fmt(x):
ret = f"{x['val']:{fmt}}"
if np.isnan(x['min']) and np.isnan(x['max']):
return ret + " (fixed)"
else:
mn = ("-∞" if np.isnan(x['min']) or x['min'] == -np.inf
else f"{x['min']:{fmt}}")
mx = ("∞" if np.isnan(x['max']) or x['max'] == np.inf
else f"{x['max']:{fmt}}")
return f"{ret} ({mn}, {mx})"
return _fmt


class BoundedMaskedColumn(MaskedColumn):
"""Masked column which can be toggled to group rows into one item
for formatting. To be set as Table's `MaskedColumn'.
"""

_omit_shape = False

@property
def shape(self):
sh = super().shape
return sh[0:-1] if self._omit_shape and len(sh) > 1 else sh

def is_fixed(self):
return ma.getmask(self)[:, 1:].all(1)


class BoundedParTableFormatter(TableFormatter):
"""Format bounded parameters.
Bounded parameters are 3-field structured arrays, with fields
'var', 'min', and 'max'. To be set as Table's `TableFormatter'.
'val', 'min', and 'max'. To be set as Table's `TableFormatter'.
"""

def _pformat_table(self, table, *args, **kwargs):
bpcols = []
tlfmt = table.meta.get('pahfit_format')
try:
colsh = [(col, col.shape) for col in table.columns.values()]
BoundedMaskedColumn._omit_shape = True
for col, sh in colsh:
if len(sh) == 2 and sh[1] == 3:
for col in table.columns.values():
if len(col.dtype) == 3: # bounded!
bpcols.append((col, col.info.format))
col.info.format = fmt_func(col.info.format or "g")
fmt = col.meta.get('pahfit_format') or tlfmt or "g"
col.info.format = fmt_func(fmt)
return super()._pformat_table(table, *args, **kwargs)
finally:
BoundedMaskedColumn._omit_shape = False
for col, fmt in bpcols:
col.info.format = fmt

def _name_and_structure(self, name, *args):
"Simplified column name: no val, min, max needed."
return name
13 changes: 6 additions & 7 deletions pahfit/features/util.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
"""pahfit.util General pahfit.features utility functions."""
import numpy as np
import numpy.ma as ma


def bounded_is_missing(val):
"""Return a mask array indicating which of the bounded values
are missing. A missing bounded value has a masked value."""
return ma.getmask(val)[..., 0]
return getattr(val['val'], 'mask', None) or np.zeros_like(val['val'], dtype=bool)


def bounded_is_fixed(val):
"""Return a mask array indicating which of the bounded values
are fixed. A fixed bounded value has masked bounds."""
return ma.getmask(val)[..., -2:].all(-1)
return np.isnan(val['min']) & np.isnan(val['max'])


def bounded_min(val):
"""Return the minimum of each bounded value passed.
Either the lower bound, or, if no such bound is set, the value itself."""
lower = val[..., 1]
return np.where(lower, lower, val[..., 0])
lower = val['min']
return np.where(lower, lower, val['val'])


def bounded_max(val):
"""Return the maximum of each bounded value passed.
Either the upper bound, or, if no such bound is set, the value itself."""
upper = val[..., 2]
return np.where(upper, upper, val[..., 0])
upper = val['max']
return np.where(upper, upper, val['val'])
Loading