Skip to content

Commit

Permalink
Simplify feature bounds using np.nan instead of masking
Browse files Browse the repository at this point in the history
Partially masked arrays are not well supported, and required elaborate
workarounds for pretty printing.  Here we use a structured array type,
and represent "fixed" bounds with np.nan instead of masking.  I.e. nan
is interpreted as the same as negative or positive infinity, unless
both bounds are nan, which indicates a fixed parameter.
  • Loading branch information
jdtsmith committed May 9, 2024
1 parent 171c645 commit d53422f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 46 deletions.
13 changes: 7 additions & 6 deletions pahfit/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def value_bounds(val, bounds):
if val is None:
val = np.ma.masked
if not bounds:
return (val,) + 2 * (np.ma.masked,) # Fixed
return (val,) + 2 * (np.nan,) # Fixed
ret = [val]
for i, b in enumerate(bounds):
if isinstance(b, str):
Expand Down Expand Up @@ -123,7 +123,8 @@ class Features(Table):
_group_attrs = set(('bounds', 'features', 'kind')) # group-level attributes
_param_attrs = set(('value', 'bounds')) # Each parameter can have these attributes
_no_bounds = set(('name', 'group', 'geometry', 'model')) # String attributes (no bounds)

_bounds_dtype = np.dtype([("val", "f4"), ("min", "f4"), ("max", "f4")])

@classmethod
def read(cls, file, *args, **kwargs):
"""Read a table from file.
Expand Down Expand Up @@ -308,16 +309,16 @@ def _construct_table(cls, inp: dict):
tables = []
for (kind, features) in inp.items():
kind_params = cls._kind_params[kind] # All params for this kind
rows = []
rows, dtypes = [], []
for (name, params) in features.items():
for missing in kind_params - params.keys():
if missing in cls._no_bounds:
params[missing] = 0.0
else:
params[missing] = value_bounds(0.0, bounds=(0.0, None))
rows.append(dict(name=name, **params))
table_columns = rows[0].keys()
t = cls(rows, names=table_columns)
dtypes.append(None if name in cls._no_bounds else cls._bounds_dtypes)
t = cls(rows, names=rows[0].keys(), dtype=dtypes)
for p in cls._kind_params[kind]:
if p not in cls._no_bounds:
t[p].info.format = "0.4g" # Nice format (customized by Formatter)
Expand Down Expand Up @@ -352,7 +353,7 @@ def mask_feature(self, name, mask_value=True):
pass
else:
# mask only the value, not the bounds
row[col_name].mask[0] = mask_value
row[col_name].mask['val'] = mask_value

def unmask_feature(self, name):
"""Remove the mask for all parameters of a feature."""
Expand Down
47 changes: 14 additions & 33 deletions pahfit/features/features_format.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,34 @@
import numpy.ma as ma
from astropy.table import MaskedColumn
import numpy as np
from astropy.table.pprint import TableFormatter


# * Special table formatting for bounded (val, min, max) values
def fmt_func(fmt):
    """Return a function that formats one bounded (val, min, max) element.

    Parameters
    ----------
    fmt : str
        Format specification (e.g. ``"0.4g"``) applied to the value and
        to each finite bound.

    Returns
    -------
    callable
        ``_fmt(x)``, where ``x`` is indexable by the field names
        ``'val'``, ``'min'`` and ``'max'``.  It returns
        ``"<val> (fixed)"`` when both bounds are nan, and
        ``"<val> (<min>, <max>)"`` otherwise.
    """

    def _bound(bound, unbounded, symbol):
        # nan, or the infinity matching this side, both mean "no limit";
        # render them as the infinity symbol rather than "nan"/"inf".
        if np.isnan(bound) or bound == unbounded:
            return symbol
        return f"{bound:{fmt}}"

    def _fmt(x):
        ret = f"{x['val']:{fmt}}"
        # Two nan bounds is the marker for a fixed parameter.
        if np.isnan(x['min']) and np.isnan(x['max']):
            return ret + " (fixed)"
        mn = _bound(x['min'], -np.inf, "-∞")
        mx = _bound(x['max'], np.inf, "∞")
        return f"{ret} ({mn}, {mx})"

    return _fmt


class BoundedMaskedColumn(MaskedColumn):
    """Masked column whose trailing axis can be hidden while formatting.

    While the class-level flag `_omit_shape` is true, `shape` reports
    the column's shape without its last axis, so each row is presented
    as a single item.  To be set as Table's `MaskedColumn'.
    """

    # Class-wide switch; when true, `shape` drops the last axis.
    _omit_shape = False

    @property
    def shape(self):
        full_shape = super().shape
        if self._omit_shape and len(full_shape) > 1:
            return full_shape[:-1]
        return full_shape

    def is_fixed(self):
        # A row counts as fixed when every entry after the first
        # along axis 1 is masked.
        bound_masks = ma.getmask(self)[:, 1:]
        return bound_masks.all(1)


class BoundedParTableFormatter(TableFormatter):
    """Format bounded parameters.

    Bounded parameters are 3-field structured arrays, with fields
    'val', 'min', and 'max'.  To be set as Table's `TableFormatter'.
    """

    def _pformat_table(self, table, *args, **kwargs):
        """Format `table`, rendering bounded columns via `fmt_func`.

        Each column whose dtype has exactly three fields temporarily
        gets a `fmt_func`-based format (built from its own format, or
        "g" if none is set).  The original formats are restored once
        the parent formatter finishes, including when it raises.
        """
        bpcols = []  # (column, original format) pairs to restore
        try:
            for col in table.columns.values():
                # len() of a dtype is its field count (0 if unstructured)
                if len(col.dtype) == 3:  # bounded!
                    bpcols.append((col, col.info.format))
                    col.info.format = fmt_func(col.info.format or "g")
            return super()._pformat_table(table, *args, **kwargs)
        finally:
            for col, fmt in bpcols:
                col.info.format = fmt
13 changes: 6 additions & 7 deletions pahfit/features/util.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
"""pahfit.util General pahfit.features utility functions."""
import numpy as np
import numpy.ma as ma


def bounded_is_missing(val):
    """Return a mask array indicating which of the bounded values
    are missing.  A missing bounded value has a masked 'val' field.

    `val` is indexed by the field name 'val'.  The result is a boolean
    array with the shape of ``val['val']``; it is all False when that
    field carries no mask.
    """
    # getmaskarray always returns a full boolean array, and avoids
    # taking the truth value of a multi-element mask.
    return np.ma.getmaskarray(val['val'])


def bounded_is_fixed(val):
    """Return a mask array indicating which of the bounded values
    are fixed.  A fixed bounded value has both bounds set to nan.

    `val` is indexed by the field names 'min' and 'max'.  A single nan
    bound does not count as fixed.
    """
    return np.isnan(val['min']) & np.isnan(val['max'])


def bounded_min(val):
    """Return the minimum of each bounded value passed.
    Either the lower bound, or, if no such bound is set, the value itself.

    `val` is indexed by the field names 'val' and 'min'.  A lower bound
    of nan means "not set".
    """
    lower = val['min']
    # Test for nan explicitly: nan is truthy and a legitimate bound of
    # 0.0 is falsy, so a plain truth test picks the wrong branch.
    return np.where(np.isnan(lower), val['val'], lower)


def bounded_max(val):
    """Return the maximum of each bounded value passed.
    Either the upper bound, or, if no such bound is set, the value itself.

    `val` is indexed by the field names 'val' and 'max'.  An upper bound
    of nan means "not set".
    """
    upper = val['max']
    # Test for nan explicitly: nan is truthy and a legitimate bound of
    # 0.0 is falsy, so a plain truth test picks the wrong branch.
    return np.where(np.isnan(upper), val['val'], upper)

0 comments on commit d53422f

Please sign in to comment.