Skip to content

Commit

Permalink
add: tests for astra default loaders
Browse files Browse the repository at this point in the history
NOTE: this requires the mwmvisit boss fix pr as well
  • Loading branch information
rileythai committed Oct 30, 2024
1 parent 7d78ccd commit c57544c
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 32 deletions.
28 changes: 12 additions & 16 deletions specutils/io/default_loaders/sdss_v.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,6 @@ def load_sdss_mwm_1d(file_obj,
if (np.array(datasums) == 0).all():
raise ValueError("Specified file is empty.")

# TODO: how should we handle this -- multiple things in file, but the user cannot choose.
if hdu is None:
for i in range(1, len(hdulist)):
if hdulist[i].header.get("DATASUM") != "0":
Expand Down Expand Up @@ -660,11 +659,13 @@ def load_astra_1d(file_obj, hdu: Optional[int] = None, **kwargs):
if (np.array(datasums) == 0).all():
raise ValueError("Specified file is empty.")

# TODO: how should we handle this -- multiple things in file, but the user cannot choose.
if hdu is None:
for i in range(len(hdulist)):
for i in range(1, len(hdulist)):
if hdulist[i].header.get("DATASUM") != "0":
hdu = i
warnings.warn(
'HDU not specified. Loading spectrum at (HDU{})'.
format(i), AstropyUserWarning)
break

return _load_astra_hdu(hdulist, hdu, **kwargs)
Expand Down Expand Up @@ -705,10 +706,6 @@ def load_astra_list(file_obj, **kwargs):
for hdu in range(1, len(hdulist)):
if hdulist[hdu].header.get("DATASUM") == "0":
# Skip zero data HDU's
# TODO: validate if we want this printed warning or not.
# it might get annoying & fill logs with useless alerts.
print("WARNING: HDU{} ({}) is empty.".format(
hdu, hdulist[hdu].name))
continue
spectra.append(_load_astra_hdu(hdulist, hdu))
return spectra
Expand Down Expand Up @@ -774,26 +771,25 @@ def _load_astra_hdu(hdulist: HDUList, hdu: int, **kwargs):
meta = dict()
meta["header"] = hdulist[0].header

# Add SNR to metadata
meta["snr"] = np.array(hdulist[hdu].data["snr"])

# Add identifiers (obj, telescope, mjd, datatype)
# TODO: need to see what metadata we're interested in for the MWM files.
#meta["telescope"] = hdulist[hdu].data["telescope"]
#meta["instrument"] = hdulist[hdu].header.get("INSTRMNT")
try:
meta["telescope"] = hdulist[hdu].data["telescope"]
meta['instrument'] = 'BOSS' if hdu <= 2 else 'APOGEE'
try: # get obj if exists
meta["obj"] = hdulist[hdu].data["obj"]
except KeyError:
pass

# choose between mwmVisit/Star via KeyError except
try:
meta["mjd"] = hdulist[hdu].data["mjd"]
meta['mjd'] = hdulist[hdu].data['mjd']
meta["datatype"] = "astraVisit"
except KeyError:
meta['min_mjd'] = str(hdulist[hdu].data["min_mjd"][0])
meta["min_mjd"] = str(hdulist[hdu].data["min_mjd"][0])
meta["max_mjd"] = str(hdulist[hdu].data["max_mjd"][0])
meta["datatype"] = "astraStar"
finally:
meta["name"] = hdulist[hdu].name
meta["sdss_id"] = hdulist[hdu].data['sdss_id']

return Spectrum1D(
spectral_axis=spectral_axis,
Expand Down
200 changes: 184 additions & 16 deletions specutils/io/default_loaders/tests/test_sdss_v.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
def generate_apogee_hdu(observatory="APO",
with_wl=True,
datasum="0",
nvisits=1):
nvisits=1,
astra=False):
wl = (10**(4.179 + 6e-6 * np.arange(8575))).reshape((1, -1))
flux = np.array([np.zeros_like(wl)] * nvisits)
ivar = np.array([np.zeros_like(wl)] * nvisits)
Expand All @@ -23,15 +24,14 @@ def generate_apogee_hdu(observatory="APO",

columns = [
fits.Column(name="spectrum_pk_id", array=[159783564], format="K"),
fits.Column(name="release", array=[b"sdss5"], format="5A"),
fits.Column(name="filetype", array=[b"apStar"], format="6A"),
fits.Column(name="v_astra", array=[b"0.5.0"], format="5A"),
fits.Column(name="release", array=[b"sdss5"]*nvisits, format="5A"),
fits.Column(name="v_astra", array=[b"0.5.0"]*nvisits, format="5A"),
fits.Column(name="healpix", array=[3], format="J"),
fits.Column(name="sdss_id", array=[42], format="K"),
fits.Column(name="apred", array=[b"1.2"], format="3A"),
fits.Column(name="sdss_id", array=[42]*nvisits, format="K"),
fits.Column(name="apred", array=[b"1.2"]*nvisits, format="3A"),
fits.Column(name="obj", array=[b"2M19534321+6705175"], format="18A"),
fits.Column(name="telescope", array=[b"apo25m"], format="6A"),
fits.Column(name="snr", array=[50], format="E"),
fits.Column(name="telescope", array=[b"apo25m"]*nvisits, format="6A"),
fits.Column(name="snr", array=[50]*nvisits, format="E"),
]
if with_wl:
columns.append(
Expand All @@ -47,8 +47,10 @@ def generate_apogee_hdu(observatory="APO",
columns += [
fits.Column(name="mjd", array=[59804], format="J"),
]
flux_col = 'model_flux' if astra else 'flux'

columns += [
fits.Column(name="flux", array=flux, format="8575E", dim="(8575)"),
fits.Column(name=flux_col, array=flux, format="8575E", dim="(8575)"),
fits.Column(name="ivar", array=ivar, format="8575E", dim="(8575)"),
fits.Column(name="pixel_flags",
array=pixel_flags,
Expand Down Expand Up @@ -84,23 +86,26 @@ def generate_apogee_hdu(observatory="APO",
return fits.BinTableHDU.from_columns(columns, header=header)


def generate_boss_hdu(observatory="APO", with_wl=True, datasum="0", nvisits=1):
def generate_boss_hdu(observatory="APO",
with_wl=True,
datasum="0",
nvisits=1,
astra=False):
wl = (10**(3.5523 + 1e-4 * np.arange(4648))).reshape((1, -1))
flux = np.array([np.zeros_like(wl)] * nvisits)
ivar = np.array([np.zeros_like(wl)] * nvisits)
pixel_flags = np.array([np.zeros_like(wl)] * nvisits)
continuum = np.array([np.zeros_like(wl)] * nvisits)
nmf_rectified_model_flux = np.array([np.zeros_like(wl)] * nvisits)
columns = [
fits.Column(name="spectrum_pk_id", array=[0], format="K"),
fits.Column(name="release", array=["sdss5"], format="5A"),
fits.Column(name="filetype", array=["specFull"], format="7A"),
fits.Column(name="spectrum_pk_id", array=[0]*nvisits, format="K"),
fits.Column(name="release", array=["sdss5"]*nvisits, format="5A"),
fits.Column(name="v_astra", array=["0.5.0"], format="5A"),
fits.Column(name="healpix", array=[34], format="J"),
fits.Column(name="sdss_id", array=[42], format="K"),
fits.Column(name="run2d", array=["6_1_2"], format="6A"),
fits.Column(name="telescope", array=["apo25m"], format="6A"),
fits.Column(name="snr", array=[50], format="E"),
fits.Column(name="telescope", array=["apo25m"]*nvisits, format="6A"),
fits.Column(name="snr", array=[50]*nvisits, format="E"),
]

if with_wl:
Expand All @@ -117,8 +122,9 @@ def generate_boss_hdu(observatory="APO", with_wl=True, datasum="0", nvisits=1):
columns += [
fits.Column(name="mjd", array=[59804], format="J"),
]
flux_col = 'model_flux' if astra else 'flux'
columns += [
fits.Column(name="flux", array=flux, format="4648E", dim="(4648)"),
fits.Column(name=flux_col, array=flux, format="4648E", dim="(4648)"),
fits.Column(name="ivar", array=ivar, format="4648E", dim="(4648)"),
fits.Column(name="pixel_flags",
array=pixel_flags,
Expand Down Expand Up @@ -464,7 +470,50 @@ def test_mwm_1d_nohdu(file_obj, hdu, with_wl, hduflags, nvisits):
assert data.flux.value.shape[-1] == length
if nvisits > 1:
assert data.flux.value.shape[0] == nvisits
if with_wl:
assert data.meta['datatype'].lower() == 'mwmstar'
else:
assert data.meta['datatype'].lower() == 'mwmvisit'
assert data.spectral_axis.unit == Angstrom
assert data.flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)


@pytest.mark.parametrize(
"file_obj, hdu, with_wl, hduflags, nvisits",
[
("mwm-temp", None, False, [0, 0, 1, 0], 1), # visit
("mwm-temp", None, False, [0, 1, 1, 0], 3), # multi-ext visits
("mwm-temp", None, True, [0, 0, 1, 0], 1), # star
("mwm-temp", None, True, [0, 1, 1, 0], 1),
],
)
def test_astra_nohdu(file_obj, hdu, with_wl, hduflags, nvisits):
"""Test astra Spectrum1D loader when HDU isn't specified"""
tmpfile = str(file_obj) + ".fits"
mwm_HDUList(hduflags, with_wl, nvisits=nvisits,
astra=True).writeto(tmpfile, overwrite=True)

with pytest.warns(AstropyUserWarning):
data = Spectrum1D.read(tmpfile, hdu=hdu)
assert isinstance(data, Spectrum1D)
assert isinstance(data.meta["header"], fits.Header)

if data.meta["instrument"].lower() == "apogee":
length = 8575
elif data.meta["instrument"].lower() == "boss":
length = 4648
else:
raise ValueError(
"INSTRMNT tag in test HDU header is not set properly.")
assert len(data.spectral_axis.value) == length
assert data.flux.value.shape[-1] == length
if nvisits > 1:
assert data.flux.value.shape[0] == nvisits
if with_wl:
assert data.meta['datatype'].lower() == 'astrastar'
else:
assert data.meta['datatype'].lower() == 'astravisit'
assert data.spectral_axis.unit == Angstrom
assert data.flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)
Expand All @@ -486,6 +535,44 @@ def test_mwm_1d(file_obj, hdu, with_wl, hduflags, nvisits):
mwm_HDUList(hduflags, with_wl, nvisits=nvisits).writeto(tmpfile,
overwrite=True)

data = Spectrum1D.read(tmpfile, hdu=hdu)
assert isinstance(data, Spectrum1D)
assert isinstance(data.meta["header"], fits.Header)
if data.meta["instrument"].lower() == "apogee":
length = 8575
elif data.meta["instrument"].lower() == "boss":
length = 4648
else:
raise ValueError(
"INSTRMNT tag in test HDU header is not set properly.")
assert len(data.spectral_axis.value) == length
assert data.flux.value.shape[-1] == length
if nvisits > 1:
assert data.flux.value.shape[0] == nvisits
if with_wl:
assert data.meta['datatype'].lower() == 'mwmstar'
else:
assert data.meta['datatype'].lower() == 'mwmvisit'
assert data.spectral_axis.unit == Angstrom
assert data.flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)


@pytest.mark.parametrize(
"file_obj, hdu, with_wl, hduflags, nvisits",
[
("astra-temp", 3, False, [0, 0, 1, 0], 1),
("astra-temp", 3, False, [0, 1, 1, 0], 5),
("astra-temp", 3, True, [0, 0, 1, 0], 1),
("astra-temp", 2, True, [0, 1, 1, 0], 1),
],
)
def test_astra_1d(file_obj, hdu, with_wl, hduflags, nvisits):
"""Test astra Spectrum1D loader"""
tmpfile = str(file_obj) + ".fits"
mwm_HDUList(hduflags, with_wl, nvisits=nvisits,
astra=True).writeto(tmpfile, overwrite=True)

data = Spectrum1D.read(tmpfile, hdu=hdu)
assert isinstance(data, Spectrum1D)
assert isinstance(data.meta["header"], fits.Header)
Expand All @@ -501,6 +588,11 @@ def test_mwm_1d(file_obj, hdu, with_wl, hduflags, nvisits):
if nvisits > 1:
assert data.flux.value.shape[0] == nvisits

if with_wl:
assert data.meta['datatype'].lower() == 'astrastar'
else:
assert data.meta['datatype'].lower() == 'astravisit'

assert data.spectral_axis.unit == Angstrom
assert data.flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)
Expand Down Expand Up @@ -550,6 +642,50 @@ def test_mwm_list(file_obj, with_wl, hduflags):
assert data[i].flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)

@pytest.mark.parametrize(
"file_obj, with_wl, hduflags",
[
("astra-temp", False, [0, 0, 1, 1]),
("astra-temp", False, [0, 1, 1, 0]),
("astra-temp", False, [1, 1, 0, 0]),
("astra-temp", False, [1, 1, 1, 1]),
("astra-temp", True, [0, 0, 1, 1]),
("astra-temp", True, [0, 1, 1, 0]),
("astra-temp", True, [1, 1, 0, 0]),
("astra-temp", True, [1, 1, 1, 1]),
],
)
def test_astra_list(file_obj, with_wl, hduflags):
"""Test astra SpectrumList loader"""
tmpfile = str(file_obj) + ".fits"
nvisits = 1 if with_wl else 3
mwm_HDUList(hduflags, with_wl, nvisits=nvisits,astra=True).writeto(tmpfile,
overwrite=True)

data = SpectrumList.read(tmpfile)
assert isinstance(data, SpectrumList)
for i in range(len(data)):
assert isinstance(data[i], Spectrum1D)
assert isinstance(data[i].meta["header"], fits.Header)
if data[i].meta["instrument"].lower() == "apogee":
length = 8575
elif data[i].meta["instrument"].lower() == "boss":
length = 4648
else:
raise ValueError(
"INSTRMNT tag in test HDU header is not set properly.")
if with_wl:
assert data[i].meta['datatype'].lower() == 'astrastar'
else:
assert data[i].meta['datatype'].lower() == 'astravisit'
assert len(data[i].spectral_axis.value) == length
assert data[i].flux.value.shape[-1] == length
if nvisits > 1:
assert data[i].flux.value.shape[0] == nvisits
assert data[i].spectral_axis.unit == Angstrom
assert data[i].flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)


@pytest.mark.parametrize(
"file_obj, hdu, hduflags",
Expand Down Expand Up @@ -585,6 +721,22 @@ def test_mwm_1d_fail(file_obj, with_wl):
Spectrum1D.read(tmpfile)
os.remove(tmpfile)

@pytest.mark.parametrize(
"file_obj, with_wl",
[
("astra-temp", False),
("astra-temp", True),
],
)
def test_astra_1d_fail(file_obj, with_wl):
"""Test astra Spectrum1D loader fail on empty"""
tmpfile = str(file_obj) + ".fits"
mwm_HDUList([0, 0, 0, 0], with_wl,astra=True).writeto(tmpfile, overwrite=True)

with pytest.raises(ValueError):
Spectrum1D.read(tmpfile)
os.remove(tmpfile)


@pytest.mark.parametrize(
"file_obj, with_wl",
Expand All @@ -602,6 +754,22 @@ def test_mwm_list_fail(file_obj, with_wl):
SpectrumList.read(tmpfile)
os.remove(tmpfile)

@pytest.mark.parametrize(
"file_obj, with_wl",
[
("astra-temp", False),
("astra-temp", True),
],
)
def test_astra_list_fail(file_obj, with_wl):
"""Test astra SpectrumList loader fail on empty"""
tmpfile = str(file_obj) + ".fits"
mwm_HDUList([0, 0, 0, 0], with_wl,astra=True).writeto(tmpfile, overwrite=True)

with pytest.raises(ValueError):
SpectrumList.read(tmpfile)
os.remove(tmpfile)


@pytest.mark.parametrize(
"file_obj,n_spectra",
Expand Down

0 comments on commit c57544c

Please sign in to comment.