Skip to content

Commit

Permalink
fix: SDSS-V SpectrumList loader ambiguity + add: BOSS-only mwm test cases
Browse files Browse the repository at this point in the history

- added new test cases for BOSS-only mwmVisit and mwmStar files
- added new checks to the SpectrumList mwmVisit/mwmStar test to verify that the identified filetype is correct
- forced override on default SpectrumList loaders -- now SpectrumList is no longer ambiguous and doesn't require a format specification
  - relevant areas in tests are updated accordingly
- added printed warnings for when the HDU is not specified in Spectrum1D loaders for files with multiple spectra.
- ensured tests now remove tempfiles with os.remove
  - arguably, this could work better with tmpfile, but I don't know how tests are deployed on the server side
  • Loading branch information
rileythai committed Oct 12, 2024
1 parent b6851a6 commit 51e8a2e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 59 deletions.
15 changes: 11 additions & 4 deletions specutils/io/default_loaders/sdss_v.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,10 @@ def load_sdss_apStar_1D(file_obj, idx: int = 0, **kwargs):


@data_loader(
"SDSS-V apStar multi",
"SDSS-V apStar",
identifier=apStar_identify,
dtype=SpectrumList,
force=True,
priority=10,
extensions=["fits"],
)
Expand Down Expand Up @@ -259,9 +260,10 @@ def load_sdss_apVisit_1D(file_obj, **kwargs):


@data_loader(
"SDSS-V apVisit multi",
"SDSS-V apVisit",
identifier=apVisit_identify,
dtype=SpectrumList,
force=True,
priority=10,
extensions=["fits"],
)
Expand Down Expand Up @@ -338,6 +340,7 @@ def load_sdss_spec_1D(file_obj, *args, hdu: Optional[int] = None, **kwargs):
"""
if hdu is None:
# TODO: how should we handle this -- multiple things in file, but the user cannot choose.
print('HDU not specified. Loading coadd spectrum (HDU1)')
hdu = 1 # defaulting to coadd
# raise ValueError("HDU not specified! Please specify a HDU to load.")
elif hdu in [2, 3, 4]:
Expand All @@ -348,9 +351,10 @@ def load_sdss_spec_1D(file_obj, *args, hdu: Optional[int] = None, **kwargs):


@data_loader(
"SDSS-V spec multi",
"SDSS-V spec",
identifier=spec_sdss5_identify,
dtype=SpectrumList,
force=True,
priority=5,
extensions=["fits"],
)
Expand Down Expand Up @@ -463,14 +467,17 @@ def load_sdss_mwm_1d(file_obj, hdu: Optional[int] = None, **kwargs):
for i in range(len(hdulist)):
if hdulist[i].header.get("DATASUM") != "0":
hdu = i
print('HDU not specified. Loading spectrum at (HDU{})'.
format(i))
break

return _load_mwmVisit_or_mwmStar_hdu(hdulist, hdu, **kwargs)


@data_loader(
"SDSS-V mwm multi",
"SDSS-V mwm",
identifier=mwm_identify,
force=True,
dtype=SpectrumList,
priority=20,
extensions=["fits"],
Expand Down
106 changes: 51 additions & 55 deletions specutils/io/default_loaders/tests/test_sdss_v.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os

import numpy as np
import pytest

from astropy.io import fits
from astropy.units import Unit, Angstrom
from astropy.units import Angstrom, Unit

from specutils import Spectrum1D, SpectrumList

Expand All @@ -25,41 +26,22 @@ def generate_apogee_hdu(observatory="APO", with_wl=True, datasum="0"):
fits.Column(name="apred", array=[b"1.2"], format="3A"),
fits.Column(name="obj", array=[b"2M19534321+6705175"], format="18A"),
fits.Column(name="telescope", array=[b"apo25m"], format="6A"),
fits.Column(name="min_mjd", array=[59804], format="J"),
fits.Column(name="max_mjd", array=[59866], format="J"),
fits.Column(name="n_entries", array=[-1], format="J"),
fits.Column(name="n_visits", array=[5], format="J"),
fits.Column(name="n_good_visits", array=[5], format="J"),
fits.Column(name="n_good_rvs", array=[5], format="J"),
fits.Column(name="snr", array=[46.56802], format="E"),
fits.Column(name="mean_fiber", array=[256.0], format="E"),
fits.Column(name="std_fiber", array=[0.0], format="E"),
fits.Column(name="spectrum_flags", array=[1048576], format="J"),
fits.Column(name="v_rad", array=[-56.7284381], format="E"),
fits.Column(name="e_v_rad", array=[5.35407624], format="E"),
fits.Column(name="std_v_rad", array=[10.79173857], format="E"),
fits.Column(name="median_e_v_rad", array=[16.19418386], format="E"),
fits.Column(name="doppler_teff", array=[7169.0107], format="E"),
fits.Column(name="doppler_e_teff", array=[9.405238], format="E"),
fits.Column(name="doppler_logg", array=[2.981389], format="E"),
fits.Column(name="doppler_e_logg", array=[0.01916536], format="E"),
fits.Column(name="doppler_fe_h", array=[-1.20532212], format="E"),
fits.Column(name="doppler_e_fe_h", array=[0.0093738], format="E"),
fits.Column(name="doppler_rchi2", array=[1.1424173], format="E"),
fits.Column(name="doppler_flags", array=[0], format="J"),
fits.Column(name="xcorr_v_rad", array=[np.nan], format="E"),
fits.Column(name="xcorr_v_rel", array=[np.nan], format="E"),
fits.Column(name="xcorr_e_v_rel", array=[np.nan], format="E"),
fits.Column(name="ccfwhm", array=[np.nan], format="E"),
fits.Column(name="autofwhm", array=[np.nan], format="E"),
fits.Column(name="n_components", array=[1], format="J"),
fits.Column(name="snr", array=[50], format="E"),
]
if with_wl:
columns.append(
fits.Column(name="wavelength",
array=wl,
format="8575E",
dim="(8575)"))
columns += [
fits.Column(name="min_mjd", array=[59804], format="J"),
fits.Column(name="max_mjd", array=[59866], format="J"),
]
else:
columns += [
fits.Column(name="mjd", array=[59804], format="J"),
]
columns += [
fits.Column(name="flux", array=flux, format="8575E", dim="(8575)"),
fits.Column(name="ivar", array=ivar, format="8575E", dim="(8575)"),
Expand Down Expand Up @@ -110,32 +92,23 @@ def generate_boss_hdu(observatory="APO", with_wl=True, datasum="0"):
fits.Column(name="sdss_id", array=[42], format="K"),
fits.Column(name="run2d", array=["6_1_2"], format="6A"),
fits.Column(name="telescope", array=["apo25m"], format="6A"),
fits.Column(name="min_mjd", array=[54], format="J"),
fits.Column(name="max_mjd", array=[488], format="J"),
fits.Column(name="n_visits", array=[1], format="J"),
fits.Column(name="n_good_visits", array=[1], format="J"),
fits.Column(name="n_good_rvs", array=[1], format="J"),
fits.Column(name="v_rad", array=[0], format="E"),
fits.Column(name="e_v_rad", array=[1], format="E"),
fits.Column(name="std_v_rad", array=[1], format="E"),
fits.Column(name="median_e_v_rad", array=[3], format="E"),
fits.Column(name="xcsao_teff", array=[5000], format="E"),
fits.Column(name="xcsao_e_teff", array=[10], format="E"),
fits.Column(name="xcsao_logg", array=[4], format="E"),
fits.Column(name="xcsao_e_logg", array=[3], format="E"),
fits.Column(name="xcsao_fe_h", array=[0], format="E"),
fits.Column(name="xcsao_e_fe_h", array=[5], format="E"),
fits.Column(name="xcsao_meanrxc", array=[0], format="E"),
fits.Column(name="snr", array=[50], format="E"),
fits.Column(name="gri_gaia_transform_flags", array=[1], format="J"),
fits.Column(name="zwarning_flags", array=[0], format="J"),
]

if with_wl:
columns.append(
fits.Column(name="wavelength",
array=wl,
format="4648E",
dim="(4648)"))
columns += [
fits.Column(name="min_mjd", array=[54], format="J"),
fits.Column(name="max_mjd", array=[488], format="J"),
]
else:
columns += [
fits.Column(name="mjd", array=[59804], format="J"),
]
columns += [
fits.Column(name="flux", array=flux, format="4648E", dim="(4648)"),
fits.Column(name="ivar", array=ivar, format="4648E", dim="(4648)"),
Expand Down Expand Up @@ -483,21 +456,28 @@ def test_mwm_1d(file_obj, hdu, with_wl, hduflags):
assert len(data.flux.value) == length
assert data.spectral_axis.unit == Angstrom
assert data.flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)


@pytest.mark.parametrize(
"file_obj, with_wl, hduflags",
[
("mwm-temp", False, [0, 0, 1, 1]),
("mwm-temp", True, [0, 0, 1, 1]),
("mwm-temp", False, [0, 1, 1, 0]),
("mwm-temp", True, [0, 1, 1, 0]),
("mwm-temp", False, [1, 1, 0, 0]),
("mwm-temp", True, [1, 1, 0, 0]),
("mwm-temp", False, [1, 1, 1, 1]),
("mwm-temp", True, [0, 1, 0, 1]),
("mwm-temp", True, [1, 1, 1, 1]),
],
)
def test_mwm_list(file_obj, with_wl, hduflags):
"""Test mwm SpectrumList loader"""
tmpfile = str(file_obj) + ".fits"
mwm_HDUList(hduflags, with_wl).writeto(tmpfile, overwrite=True)

data = SpectrumList.read(tmpfile, format="SDSS-V mwm multi")
data = SpectrumList.read(tmpfile)
assert isinstance(data, SpectrumList)
for i in range(len(data)):
assert isinstance(data[i], Spectrum1D)
Expand All @@ -509,10 +489,15 @@ def test_mwm_list(file_obj, with_wl, hduflags):
else:
raise ValueError(
"INSTRMNT tag in test HDU header is not set properly.")
if with_wl:
assert data[i].meta['datatype'].lower() == 'mwmstar'
else:
assert data[i].meta['datatype'].lower() == 'mwmvisit'
assert len(data[i].spectral_axis.value) == length
assert len(data[i].flux.value) == length
assert data[i].spectral_axis.unit == Angstrom
assert data[i].flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -530,6 +515,7 @@ def test_mwm_1d_fail_spec(file_obj, hdu, hduflags):
mwm_HDUList(hduflags, True).writeto(tmpfile, overwrite=True)
with pytest.raises(IndexError):
Spectrum1D.read(tmpfile, hdu=hdu)
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -546,6 +532,7 @@ def test_mwm_1d_fail(file_obj, with_wl):

with pytest.raises(ValueError):
Spectrum1D.read(tmpfile)
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -561,7 +548,8 @@ def test_mwm_list_fail(file_obj, with_wl):
mwm_HDUList([0, 0, 0, 0], with_wl).writeto(tmpfile, overwrite=True)

with pytest.raises(ValueError):
SpectrumList.read(tmpfile, format="SDSS-V mwm multi")
SpectrumList.read(tmpfile)
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -589,6 +577,7 @@ def test_spec_1d(file_obj, n_spectra):
assert len(data.mask) == 10

assert data[i].meta["header"].get("foobar") == "barfoo"
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -603,7 +592,7 @@ def test_spec_list(file_obj, n_spectra):
tmpfile = str(file_obj) + ".fits"
spec_HDUList(n_spectra).writeto(tmpfile, overwrite=True)

data = SpectrumList.read(tmpfile, format="SDSS-V spec multi")
data = SpectrumList.read(tmpfile)
assert isinstance(data, SpectrumList)
assert len(data) == n_spectra + 1
for i in range(n_spectra):
Expand All @@ -613,6 +602,7 @@ def test_spec_list(file_obj, n_spectra):
assert data[i].spectral_axis.unit == Angstrom
assert len(data[i].mask) == 10
assert data[i].meta["header"].get("foobar") == "barfoo"
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -631,6 +621,7 @@ def test_spec_1d_fail_hdu(file_obj, hdu):

with pytest.raises(ValueError):
Spectrum1D.read(tmpfile, hdu=hdu)
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -653,6 +644,7 @@ def test_apStar_1D(file_obj, idx):
assert data.spectral_axis.unit == Angstrom

assert data.meta["header"].get("foobar") == "barfoo"
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -666,7 +658,7 @@ def test_apStar_list(file_obj, n_spectra):
tmpfile = str(file_obj) + ".fits"
apStar_HDUList(n_spectra).writeto(tmpfile, overwrite=True)

data = SpectrumList.read(tmpfile, format="SDSS-V apStar multi")
data = SpectrumList.read(tmpfile)
assert isinstance(data, SpectrumList)
assert len(data) == n_spectra
for i in range(len(data)):
Expand All @@ -675,6 +667,7 @@ def test_apStar_list(file_obj, n_spectra):
assert data[i].flux.unit == Unit("1e-17 erg / (s cm2 Angstrom)")
assert len(data[i].spectral_axis.value) == 10
assert data[i].meta["header"].get("foobar") == "barfoo"
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -689,7 +682,8 @@ def test_apStar_fail_list(file_obj):
apStar_HDUList(1).writeto(tmpfile, overwrite=True)

with pytest.raises(ValueError):
SpectrumList.read(tmpfile, format="SDSS-V apStar multi")
SpectrumList.read(tmpfile)
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -707,6 +701,7 @@ def test_apVisit_1D(file_obj):
assert np.array_equal(data.spectral_axis.value, np.arange(1, 31, 1))
assert len(data.flux.value) == 30
assert data.meta["header"].get("foobar") == "barfoo"
os.remove(tmpfile)


@pytest.mark.parametrize(
Expand All @@ -719,6 +714,7 @@ def test_apVisit_list(file_obj):
tmpfile = str(file_obj) + ".fits"
apVisit_HDUList().writeto(tmpfile, overwrite=True)

data = SpectrumList.read(tmpfile, format="SDSS-V apVisit multi")
data = SpectrumList.read(tmpfile)
assert isinstance(data, SpectrumList)
assert len(data) == 3
os.remove(tmpfile)

0 comments on commit 51e8a2e

Please sign in to comment.