From fd2707286acfba1125e0e831802ccc3cfb406140 Mon Sep 17 00:00:00 2001 From: Joachim Moeyens Date: Thu, 30 Nov 2023 12:44:15 -0800 Subject: [PATCH] Bug fix: Correctly select the coincident attribution with the lowest distance --- thor/orbits/attribution.py | 20 ++++++++++------ thor/orbits/tests/test_attribution.py | 34 ++++++++++++++------------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/thor/orbits/attribution.py b/thor/orbits/attribution.py index 78fc8112..83d1fc4c 100644 --- a/thor/orbits/attribution.py +++ b/thor/orbits/attribution.py @@ -54,6 +54,11 @@ def drop_coincident_attributions( # Flatten the table so nested columns are dot-delimited at the top level flattened_table = self.flattened_table() + # Add index to flattened table + flattened_table = flattened_table.add_column( + 0, "index", pa.array(np.arange(len(flattened_table))) + ) + # Drop the residual values (a list column) due to: https://github.com/apache/arrow/issues/32504 flattened_table = flattened_table.drop(["residuals.values"]) @@ -76,17 +81,13 @@ def drop_coincident_attributions( flattened_observations, ["obs_id"], right_keys=["id"] ) - # Add index column - flattened_table = flattened_table.add_column( - 0, "index", pa.array(np.arange(len(flattened_table))) - ) - # Sort the table flattened_table = flattened_table.sort_by( [ ("orbit_id", "ascending"), ("coordinates.time.days", "ascending"), ("coordinates.time.nanos", "ascending"), + ("distance", "ascending"), ] ) @@ -100,7 +101,11 @@ def drop_coincident_attributions( .column("index_first") ) - return self.take(indices) + filtered = self.take(indices) + if filtered.fragmented(): + filtered = qv.defragment(filtered) + + return filtered def attribution_worker( @@ -336,6 +341,8 @@ def attribute_observations( attributions = qv.concatenate(attributions_list) attributions = attributions.sort_by(["orbit_id", "obs_id", "distance"]) + if attributions.fragmented(): + attributions = qv.defragment(attributions) time_end = time.time() logger.info( @@ -459,7 +466,6 @@ def merge_and_extend_orbits( # the same time, keep only observation with smallest distance attributions = attributions.drop_coincident_attributions(observations) - attributions = qv.defragment(attributions) # Create a new orbit members table with the newly attributed observations and # filter the orbits to only include those that still have observations orbit_members_iter = FittedOrbitMembers.from_kwargs( diff --git a/thor/orbits/tests/test_attribution.py b/thor/orbits/tests/test_attribution.py index 11b78d64..6e87c742 100644 --- a/thor/orbits/tests/test_attribution.py +++ b/thor/orbits/tests/test_attribution.py @@ -9,32 +9,34 @@ def test_Attributions_drop_coincident_attributions(): observations = Observations.from_kwargs( - id=["01", "02", "03", "04"], - exposure_id=["e01", "e01", "e02", "e02"], + id=["01", "02", "03", "04", "05"], + exposure_id=["e01", "e01", "e02", "e02", "e02"], coordinates=SphericalCoordinates.from_kwargs( - time=Timestamp.from_mjd([59001.1, 59001.1, 59002.1, 59002.1], scale="utc"), - lon=[1, 2, 3, 4], - lat=[5, 6, 7, 8], - origin=Origin.from_kwargs(code=["500", "500", "500", "500"]), + time=Timestamp.from_mjd( + [59001.1, 59001.1, 59002.1, 59002.1, 59002.1], scale="utc" + ), + lon=[1, 2, 3, 4, 5], + lat=[5, 6, 7, 8, 9], + origin=Origin.from_kwargs(code=["500", "500", "500", "500", "500"]), ), photometry=Photometry.from_kwargs( - filter=["g", "g", "g", "g"], - mag=[10, 11, 12, 13], + filter=["g", "g", "g", "g", "g"], + mag=[10, 11, 12, 13, 14], ), - state_id=[0, 0, 1, 1], + state_id=[0, 0, 1, 1, 1], ) attributions = Attributions.from_kwargs( - orbit_id=["o01", "o01", "o02", "o03"], - obs_id=["01", "02", "03", "03"], - distance=[0.5 / 3600, 1 / 3600, 2 / 3600, 1 / 3600], + orbit_id=["o01", "o01", "o02", "o03", "o04", "o04"], + obs_id=["01", "02", "03", "03", "04", "05"], + distance=[1 / 3600, 0.5 / 3600, 2 / 3600, 1 / 3600, 2 / 3600, 1 / 3600], ) filtered = attributions.drop_coincident_attributions(observations) # Orbit 1 gets linked to two observations at the same time # We should expect to only keep the one with the smallest distance # Orbit 2 and 3 get linked to the same observation but we should keep both - assert len(filtered) == 3 - assert filtered.orbit_id.to_pylist() == ["o01", "o02", "o03"] - assert filtered.obs_id.to_pylist() == ["01", "03", "03"] - assert filtered.distance.to_pylist() == [0.5 / 3600, 2 / 3600, 1 / 3600] + assert len(filtered) == 4 + assert filtered.orbit_id.to_pylist() == ["o01", "o02", "o03", "o04"] + assert filtered.obs_id.to_pylist() == ["02", "03", "03", "05"] + assert filtered.distance.to_pylist() == [0.5 / 3600, 2 / 3600, 1 / 3600, 1 / 3600]