diff --git a/CHANGES.rst b/CHANGES.rst index 251cb20f..281070b8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -37,6 +37,9 @@ Bug Fixes - Fix ``rebuild_fits_rec_dtype`` handling of unsigned integer columns with shapes [#213] +- Fix unit roundtripping when writing to a datamodel with a table + to a FITS file [#242] + Changes to API -------------- diff --git a/src/stdatamodels/jwst/datamodels/tests/test_fits.py b/src/stdatamodels/jwst/datamodels/tests/test_fits.py index 4ba9bb96..5f54b041 100644 --- a/src/stdatamodels/jwst/datamodels/tests/test_fits.py +++ b/src/stdatamodels/jwst/datamodels/tests/test_fits.py @@ -5,7 +5,7 @@ import pytest from stdatamodels.jwst import datamodels -from stdatamodels.jwst.datamodels import ImageModel, JwstDataModel, RampModel +from stdatamodels.jwst.datamodels import ImageModel, JwstDataModel, RampModel, SpecModel @pytest.fixture @@ -91,3 +91,18 @@ def test_resave_duplication_bug(tmp_path): with fits.open(fn1) as ff1, fits.open(fn2) as ff2: assert ff1['ASDF'].size == ff2['ASDF'].size + + +def test_units_roundtrip(tmp_path): + m = SpecModel() + # this next line is required for stdatamodels to cast + # spec_table to a FITS_rec (similar to having data assigned + # to the attribute) + m.spec_table = m.spec_table + m.spec_table.columns['WAVELENGTH'].unit = 'nm' + + fn = tmp_path / "test1.fits" + m.save(fn) + + m = datamodels.open(fn) + assert m.spec_table.columns['WAVELENGTH'].unit == 'nm' diff --git a/src/stdatamodels/properties.py b/src/stdatamodels/properties.py index 113c8e00..caf44f42 100644 --- a/src/stdatamodels/properties.py +++ b/src/stdatamodels/properties.py @@ -83,10 +83,20 @@ def _cast(val, schema): t['shape'] = shape dtype = ndarray.asdf_datatype_to_numpy_dtype(schema['datatype']) + + # save columns in case this is cast back to a fitsrec + if hasattr(val, 'columns'): + cols = val.columns + else: + cols = None val = util.gentle_asarray(val, dtype, allow_extra_columns=allow_extra_columns) if dtype.fields is not None: val = _as_fitsrec(val) + if cols is not None: + for col in cols: + if col.name in val.names and col.unit is not None: + val.columns[col.name].unit = col.unit if 'ndim' in schema and len(val.shape) != schema['ndim']: raise ValueError(