Skip to content

Commit

Permalink
Allow for coinciding values for ra range
Browse files Browse the repository at this point in the history
  • Loading branch information
camposandro committed Nov 22, 2024
1 parent 263ad21 commit 9387b06
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ sphinx-autoapi
sphinx-copybutton
sphinx-book-theme
sphinx-design
git+https://github.com/astronomy-commons/hats.git@margin
git+https://github.com/astronomy-commons/hats.git@issue/401/migrate-box-filter
git+https://github.com/astronomy-commons/hats-import.git@margin
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
git+https://github.com/astronomy-commons/hats.git@margin
git+https://github.com/astronomy-commons/hats.git@issue/401/migrate-box-filter
git+https://github.com/lincc-frameworks/nested-pandas.git@main
git+https://github.com/lincc-frameworks/nested-dask.git@main
28 changes: 11 additions & 17 deletions src/lsdb/core/search/box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,22 @@ def box_filter(
Returns:
A new DataFrame with the rows from `data_frame` filtered to only the points inside the box region.
"""
mask = np.ones(len(data_frame), dtype=bool)
if ra is not None:
ra_values = data_frame[metadata.ra_column]
wrapped_ra = np.asarray(wrap_ra_angles(ra_values))
mask_ra = _create_ra_mask(ra, wrapped_ra)
mask = np.logical_and(mask, mask_ra)
if dec is not None:
dec_values = data_frame[metadata.dec_column].to_numpy()
mask_dec = np.logical_and(dec[0] <= dec_values, dec_values <= dec[1])
mask = np.logical_and(mask, mask_dec)
data_frame = data_frame.iloc[mask]
ra_values = data_frame[metadata.ra_column].to_numpy()
dec_values = data_frame[metadata.dec_column].to_numpy()
wrapped_ra = wrap_ra_angles(ra_values)
mask_ra = _create_ra_mask(ra, wrapped_ra)
mask_dec = (dec[0] <= dec_values) & (dec_values <= dec[1])
data_frame = data_frame.iloc[mask_ra & mask_dec]
return data_frame


def _create_ra_mask(ra: tuple[float, float], values: np.ndarray) -> np.ndarray:
"""Creates the mask to filter right ascension values. If this range crosses
the discontinuity line (0 degrees), we have a branched logical operation."""
if ra[0] <= ra[1]:
mask = np.logical_and(ra[0] <= values, values <= ra[1])
if ra[0] == ra[1]:
return np.ones(len(values), dtype=bool)
if ra[0] < ra[1]:
mask = (values >= ra[0]) & (values <= ra[1])
else:
mask = np.logical_or(
np.logical_and(ra[0] <= values, values <= 360),
np.logical_and(0 <= values, values <= ra[1]),
)
mask = ((values >= ra[0]) & (values <= 360)) | ((values >= 0) & (values <= ra[1]))
return mask
25 changes: 20 additions & 5 deletions tests/lsdb/catalog/test_box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ def test_box_search_ra_wrapped_filters_correct_points(small_sky_order1_catalog):
assert np.array_equal(ra_values, filtered_ra_values)


def test_box_search_ra_boundary(small_sky_order1_catalog):
dec = (-40, -30)
ra_column = small_sky_order1_catalog.hc_structure.catalog_info.ra_column
dec_column = small_sky_order1_catalog.hc_structure.catalog_info.dec_column

ra_search_catalog = small_sky_order1_catalog.box_search(ra=(0, 0), dec=dec)
ra_search_df = ra_search_catalog.compute()
ra_values = ra_search_df[ra_column]
dec_values = ra_search_df[dec_column]

assert len(ra_search_df) > 0
assert all((0 <= ra <= 360) for ra in ra_values)
assert all((-40 <= dec <= -30) for dec in dec_values)

for ra_range in [(0, 360), (360, 0)]:
catalog_df = small_sky_order1_catalog.box_search(ra=ra_range, dec=dec).compute()
assert np.array_equal(catalog_df[ra_column], ra_values)
assert np.array_equal(catalog_df[dec_column], dec_values)


def test_box_search_filters_partitions(small_sky_order1_catalog):
ra = (280, 300)
dec = (-40, -30)
Expand Down Expand Up @@ -100,11 +120,6 @@ def test_box_search_invalid_args(small_sky_order1_catalog):
small_sky_order1_catalog.box_search(ra=(0, 30), dec=(-40, -30, 10))
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box_search(ra=(0, 30, 40), dec=(-40, 10))
# The range values coincide (for ra, values are wrapped)
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box_search(ra=(100, 100), dec=(-40, -30))
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box_search(ra=(0, 360), dec=(-40, -30))
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_RADEC_RANGE):
small_sky_order1_catalog.box_search(ra=(0, 50), dec=(50, 50))
# No range values were provided
Expand Down

0 comments on commit 9387b06

Please sign in to comment.