From d53422fecbccfed622f9d3059c5628d96914b0ae Mon Sep 17 00:00:00 2001 From: JD Smith <93749+jdtsmith@users.noreply.github.com> Date: Wed, 8 May 2024 21:40:49 -0400 Subject: [PATCH] Simplify feature bounds using np.nan instead of masking Partially masked arrays are not well supported, and required an elaborate workaround for pretty printing. Here we use a structured array type, and represent "fixed" bounds with np.nan instead of masking. I.e. nan is interpreted as the same as negative or positive infinity, unless both bounds are nan, which indicates a fixed parameter. --- pahfit/features/features.py | 13 +++++---- pahfit/features/features_format.py | 47 +++++++++--------------------- pahfit/features/util.py | 13 ++++----- 3 files changed, 27 insertions(+), 46 deletions(-) diff --git a/pahfit/features/features.py b/pahfit/features/features.py index da5550e5..1d6d5019 100644 --- a/pahfit/features/features.py +++ b/pahfit/features/features.py @@ -82,7 +82,7 @@ def value_bounds(val, bounds): if val is None: val = np.ma.masked if not bounds: - return (val,) + 2 * (np.ma.masked,) # Fixed + return (val,) + 2 * (np.nan,) # Fixed ret = [val] for i, b in enumerate(bounds): if isinstance(b, str): @@ -123,7 +123,8 @@ class Features(Table): _group_attrs = set(('bounds', 'features', 'kind')) # group-level attributes _param_attrs = set(('value', 'bounds')) # Each parameter can have these attributes _no_bounds = set(('name', 'group', 'geometry', 'model')) # String attributes (no bounds) - + _bounds_dtype = np.dtype([("val", "f4"), ("min", "f4"), ("max", "f4")]) + @classmethod def read(cls, file, *args, **kwargs): """Read a table from file.
@@ -308,7 +309,7 @@ def _construct_table(cls, inp: dict): tables = [] for (kind, features) in inp.items(): kind_params = cls._kind_params[kind] # All params for this kind - rows = [] + rows, dtypes = [], [] for (name, params) in features.items(): for missing in kind_params - params.keys(): if missing in cls._no_bounds: @@ -316,8 +317,8 @@ def _construct_table(cls, inp: dict): else: params[missing] = value_bounds(0.0, bounds=(0.0, None)) rows.append(dict(name=name, **params)) - table_columns = rows[0].keys() - t = cls(rows, names=table_columns) + dtypes.append(None if name in cls._no_bounds else cls._bounds_dtypes) + t = cls(rows, names=rows[0].keys(), dtype=dtypes) for p in cls._kind_params[kind]: if p not in cls._no_bounds: t[p].info.format = "0.4g" # Nice format (customized by Formatter) @@ -352,7 +353,7 @@ def mask_feature(self, name, mask_value=True): pass else: # mask only the value, not the bounds - row[col_name].mask[0] = mask_value + row[col_name].mask['val'] = mask_value def unmask_feature(self, name): """Remove the mask for all parameters of a feature.""" diff --git a/pahfit/features/features_format.py b/pahfit/features/features_format.py index 78369727..9526b6ed 100644 --- a/pahfit/features/features_format.py +++ b/pahfit/features/features_format.py @@ -1,53 +1,34 @@ -import numpy.ma as ma -from astropy.table import MaskedColumn +import numpy as np from astropy.table.pprint import TableFormatter - # * Special table formatting for bounded (val, min, max) values def fmt_func(fmt): - def _fmt(v): - if ma.is_masked(v[0]): - return " " - if ma.is_masked(v[1]): - return f"{v[0]:{fmt}} (Fixed)" - return f"{v[0]:{fmt}} ({v[1]:{fmt}}, {v[2]:{fmt}})" - + def _fmt(x): + ret = f"{x['val']:{fmt}}" + if np.isnan(x['min']) and np.isnan(x['max']): + return ret + " (fixed)" + else: + mn = ("-∞" if np.isnan(x['min']) or x['min'] == -np.inf + else f"{x['min']:{fmt}}") + mx = ("∞" if np.isnan(x['max']) or x['max'] == np.inf + else f"{x['max']:{fmt}}") + return f"{ret} 
({mn}, {mx})" return _fmt -class BoundedMaskedColumn(MaskedColumn): - """Masked column which can be toggled to group rows into one item - for formatting. To be set as Table's `MaskedColumn'. - """ - - _omit_shape = False - - @property - def shape(self): - sh = super().shape - return sh[0:-1] if self._omit_shape and len(sh) > 1 else sh - - def is_fixed(self): - return ma.getmask(self)[:, 1:].all(1) - - class BoundedParTableFormatter(TableFormatter): """Format bounded parameters. Bounded parameters are 3-field structured arrays, with fields - 'var', 'min', and 'max'. To be set as Table's `TableFormatter'. + 'val', 'min', and 'max'. To be set as Table's `TableFormatter'. """ - def _pformat_table(self, table, *args, **kwargs): bpcols = [] try: - colsh = [(col, col.shape) for col in table.columns.values()] - BoundedMaskedColumn._omit_shape = True - for col, sh in colsh: - if len(sh) == 2 and sh[1] == 3: + for col in table.columns.values(): + if len(col.dtype) == 3 # bounded! bpcols.append((col, col.info.format)) col.info.format = fmt_func(col.info.format or "g") return super()._pformat_table(table, *args, **kwargs) finally: - BoundedMaskedColumn._omit_shape = False for col, fmt in bpcols: col.info.format = fmt diff --git a/pahfit/features/util.py b/pahfit/features/util.py index 68513aa0..92c18020 100644 --- a/pahfit/features/util.py +++ b/pahfit/features/util.py @@ -1,29 +1,28 @@ """pahfit.util General pahfit.features utility functions.""" import numpy as np -import numpy.ma as ma def bounded_is_missing(val): """Return a mask array indicating which of the bounded values are missing. A missing bounded value has a masked value.""" - return ma.getmask(val)[..., 0] + return getattr(a['val'], 'mask', None) or np.zeros_like(a['val'], dtype=bool) def bounded_is_fixed(val): """Return a mask array indicating which of the bounded values are fixed. 
A fixed bounded value has masked bounds.""" - return ma.getmask(val)[..., -2:].all(-1) + return np.isnan(val['min']) & np.isnan(val['max']) def bounded_min(val): """Return the minimum of each bounded value passed. Either the lower bound, or, if no such bound is set, the value itself.""" - lower = val[..., 1] - return np.where(lower, lower, val[..., 0]) + lower = val['min'] + return np.where(lower, lower, val['val']) def bounded_max(val): """Return the maximum of each bounded value passed. Either the upper bound, or, if no such bound is set, the value itself.""" - upper = val[..., 2] - return np.where(upper, upper, val[..., 0]) + upper = val['max'] + return np.where(upper, upper, val['val'])