diff --git a/src/upstage_des/data_types.py b/src/upstage_des/data_types.py index 3554f6b..3ebc0cc 100644 --- a/src/upstage_des/data_types.py +++ b/src/upstage_des/data_types.py @@ -557,6 +557,15 @@ def make_location(self) -> CartesianLocation: x=self.x, y=self.y, z=self.z, use_altitude_units=self.use_altitude_unts ) + def __eq__(self, value: Any) -> bool: + """Test for equality of two cartesian locations data objects.""" + if not isinstance(value, CartesianLocationData): + raise ValueError( + f"Cannot compare a {value.__class__.__name__} to a CartesianLocationData" + ) + check = ["x", "y", "z"] + return all(getattr(self, c) == getattr(value, c) for c in check) + class GeodeticLocationData: """Object for storing geodetic data without an environment.""" @@ -596,3 +605,18 @@ def make_location(self) -> GeodeticLocation: alt=self.alt, in_radians=self.in_radians, ) + + def __eq__(self, value: Any) -> bool: + """Test for equality of two cartesian locations data objects.""" + if not isinstance(value, GeodeticLocationData): + raise ValueError( + f"Cannot compare a {value.__class__.__name__} to a GeodeticLocationData" + ) + other = [value.lat, value.lon] + if value.in_radians != self.in_radians: + if self.in_radians: + other = list(map(radians, other)) + else: + other = list(map(degrees, other)) + angles = [self.lat, self.lon] == other + return angles and self.alt == value.alt diff --git a/src/upstage_des/test/test_data_types.py b/src/upstage_des/test/test_data_types.py index 92e0b22..a1f1282 100644 --- a/src/upstage_des/test/test_data_types.py +++ b/src/upstage_des/test/test_data_types.py @@ -105,3 +105,39 @@ def test_geodetic() -> None: assert loc_up_rad - UP.GeodeticLocation(10, 10) > 0 assert UP.GeodeticLocation(10, 10) - loc_up_rad > 0 + + +def test_data_objects() -> None: + cart1 = UP.CartesianLocationData(1.0, 2.1, 3.2) + cart2 = UP.CartesianLocationData(1.0, 2.1, 3.2) + assert cart1 == cart2 + + with pytest.raises(ValueError): + cart1 == (1.0, 2.1, 3.2) + + geo1 = UP.GeodeticLocationData(13.0, 12.1, 11.2) + geo2 = UP.GeodeticLocationData(13.0, 12.1, 11.2) + assert geo1 == geo2 + + geo3 = UP.GeodeticLocationData(radians(13.0), radians(12.1), 11.2, in_radians=True) + assert geo1 == geo3 + assert geo3 == geo1 + + with pytest.raises(ValueError): + geo1 == (13.0, 12.1, 11.2) + + with UP.EnvironmentContext(): + for k, v in STAGE_SETUP.items(): + UP.add_stage_variable(k, v) + + loc1 = geo1.make_location() + assert isinstance(loc1, UP.GeodeticLocation) + + loc3 = geo3.make_location() + + assert loc1 - loc3 == 0.0 + assert loc1 == loc3 + + loc1 = cart1.make_location() + loc2 = cart2.make_location() + assert loc1 == loc2