From acbebfb017c58103911c48d3ca3dec2b4d565e8a Mon Sep 17 00:00:00 2001 From: Aleksi Torhamo Date: Thu, 19 Sep 2024 00:26:44 +0300 Subject: [PATCH] Fix __eq__ and __ne__ for classes implementing them Main issue, which concerns Vector, Location, and ShapeList: Comparison with an object of a different type should not cause an exception - they are simply not equal. Raising an exception in __eq__ can (and will*) break unrelated code that expects __eq__ to be well-behaved. (* I noticed this bug when cq-editor choked on it while trying to find a name for an object in a dictionary of local variables) There's a second more minor issue, which concerns the rest of the classes: When the other type in __eq__ is not supported, one should technically return NotImplemented instead of False, to allow the other type to take part in the comparison, in case they know about our type. (__ne__ should also not generally be implemented as just the negation of __eq__ because of this, but that's also a moot point because the __ne__ can just be removed - Python will automatically do the right thing based on __eq__ here) Technically, the __eq__ for Vector and Plane is also broken in another way: It's not transitive. >>> a, b, c = Vector(0), Vector(9e-6), Vector(18e-6) >>> a == b == c True >>> a == c False They should really eg. have a separate is_close() for approximate comparison, but this isn't fixed here, since I have no idea how many places it'd break, for one. --- src/build123d/geometry.py | 51 +++++++++++++++--------------------- src/build123d/topology.py | 12 ++++++--- tests/test_direct_api.py | 54 ++++++++++++++++++++++++++++++++++++--- 3 files changed, 80 insertions(+), 37 deletions(-) diff --git a/src/build123d/geometry.py b/src/build123d/geometry.py index 185dc720..009b480b 100644 --- a/src/build123d/geometry.py +++ b/src/build123d/geometry.py @@ -444,8 +444,10 @@ def __repr__(self) -> str: __str__ = __repr__ - def __eq__(self, other: Vector) -> bool: # type: ignore[override] + def __eq__(self, other: object) -> bool: """Vectors equal operator ==""" + if not isinstance(other, Vector): + return NotImplemented return self.wrapped.IsEqual(other.wrapped, 0.00001, 0.00001) def __hash__(self) -> int: @@ -670,7 +672,7 @@ def __str__(self) -> str: def __eq__(self, other: object) -> bool: if not isinstance(other, Axis): - return False + return NotImplemented return self.position == other.position and self.direction == other.direction def located(self, new_location: Location): @@ -1468,10 +1470,10 @@ def __mul__(self, other: T) -> T: def __pow__(self, exponent: int) -> Location: return Location(self.wrapped.Powered(exponent)) - def __eq__(self, other: Location) -> bool: + def __eq__(self, other: object) -> bool: """Compare Locations""" if not isinstance(other, Location): - raise ValueError("other must be a Location") + return NotImplemented quaternion1 = gp_Quaternion() quaternion1.SetEulerAngles( gp_EulerSequence.gp_Intrinsic_XYZ, @@ -2139,27 +2141,6 @@ def offset(self, amount: float) -> Plane: origin=self.origin + self.z_dir * amount, x_dir=self.x_dir, z_dir=self.z_dir ) - def _eq_iter(self, other: Plane): - """Iterator to successively test equality - - Args: - other: Plane to compare to - - Returns: - Are planes equal - """ - # equality tolerances - eq_tolerance_origin = 1e-6 - eq_tolerance_dot = 1e-6 - - yield isinstance(other, Plane) # comparison is with another Plane - # origins are the same - yield abs(self._origin - other.origin) < eq_tolerance_origin - # z-axis vectors are parallel (assumption: both are unit vectors) - yield abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot - # x-axis vectors are parallel (assumption: both are unit vectors) - yield abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot - def __copy__(self) -> Plane: """Return copy of self""" return Plane(gp_Pln(self.wrapped.Position())) @@ -2168,13 +2149,23 @@ def __deepcopy__(self, _memo) -> Plane: """Return deepcopy of self""" return Plane(gp_Pln(self.wrapped.Position())) - def __eq__(self, other: Plane): + def __eq__(self, other: object): """Are planes equal operator ==""" - return all(self._eq_iter(other)) + if not isinstance(other, Plane): + return NotImplemented - def __ne__(self, other: Plane): - """Are planes not equal operator !+""" - return not self.__eq__(other) + # equality tolerances + eq_tolerance_origin = 1e-6 + eq_tolerance_dot = 1e-6 + + return ( + # origins are the same + abs(self._origin - other.origin) < eq_tolerance_origin + # z-axis vectors are parallel (assumption: both are unit vectors) + and abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot + # x-axis vectors are parallel (assumption: both are unit vectors) + and abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot + ) def __neg__(self) -> Plane: """Reverse z direction of plane operator -""" diff --git a/src/build123d/topology.py b/src/build123d/topology.py index 5e4f507c..269b43a0 100644 --- a/src/build123d/topology.py +++ b/src/build123d/topology.py @@ -1972,7 +1972,7 @@ def is_equal(self, other: Shape) -> bool: def __eq__(self, other) -> bool: """Are shapes same operator ==""" - return self.is_same(other) if isinstance(other, Shape) else False + return self.is_same(other) if isinstance(other, Shape) else NotImplemented def is_valid(self) -> bool: """Returns True if no defect is detected on the shape S or any of its @@ -3704,9 +3704,15 @@ def __or__(self, filter_by: Union[Axis, GeomType] = Axis.Z): """Filter by axis or geomtype operator |""" return self.filter_by(filter_by) - def __eq__(self, other: ShapeList): + def __eq__(self, other: object): """ShapeLists equality operator ==""" - return set(self) == set(other) + return set(self) == set(other) if isinstance(other, ShapeList) else NotImplemented + + # Normally implementing __eq__ is enough, but ShapeList subclasses list, + # which already implements __ne__, so we need to override it, too + def __ne__(self, other: ShapeList): + """ShapeLists inequality operator !=""" + return set(self) != set(other) if isinstance(other, ShapeList) else NotImplemented def __add__(self, other: ShapeList): """Combine two ShapeLists together operator +""" diff --git a/tests/test_direct_api.py b/tests/test_direct_api.py index dd027232..b6e7aab1 100644 --- a/tests/test_direct_api.py +++ b/tests/test_direct_api.py @@ -93,6 +93,12 @@ RAD2DEG = 180 / math.pi +# Always equal to any other object, to test that __eq__ cooperation is working +class AlwaysEqual: + def __eq__(self, other): + return True + + class DirectApiTestCase(unittest.TestCase): def assertTupleAlmostEquals( self, @@ -363,13 +369,13 @@ def test_axis_equal(self): self.assertEqual(Axis.X, Axis.X) self.assertEqual(Axis.Y, Axis.Y) self.assertEqual(Axis.Z, Axis.Z) + self.assertEqual(Axis.X, AlwaysEqual()) def test_axis_not_equal(self): self.assertNotEqual(Axis.X, Axis.Y) random_obj = object() self.assertNotEqual(Axis.X, random_obj) - class TestBoundBox(DirectApiTestCase): def test_basic_bounding_box(self): v = Vertex(1, 1, 1) @@ -1730,15 +1736,21 @@ def test_to_axis(self): self.assertVectorAlmostEquals(axis.position, (1, 2, 3), 6) self.assertVectorAlmostEquals(axis.direction, (0, 1, 0), 6) - def test_eq(self): + def test_equal(self): loc = Location((1, 2, 3), (4, 5, 6)) - diff_position = Location((10, 20, 30), (4, 5, 6)) - diff_orientation = Location((1, 2, 3), (40, 50, 60)) same = Location((1, 2, 3), (4, 5, 6)) self.assertEqual(loc, same) + self.assertEqual(loc, AlwaysEqual()) + + def test_not_equal(self): + loc = Location((1, 2, 3), (40, 50, 60)) + diff_position = Location((3, 2, 1), (40, 50, 60)) + diff_orientation = Location((1, 2, 3), (60, 50, 40)) + self.assertNotEqual(loc, diff_position) self.assertNotEqual(loc, diff_orientation) + self.assertNotEqual(loc, object()) def test_neg(self): loc = Location((1, 2, 3), (0, 35, 127)) @@ -2666,6 +2678,8 @@ def test_plane_equal(self): Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)), Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)), ) + # __eq__ cooperation + self.assertEqual(Plane.XY, AlwaysEqual()) def test_plane_not_equal(self): # type difference @@ -2955,6 +2969,17 @@ def test_is_equal(self): box = Solid.make_box(1, 1, 1) self.assertTrue(box.is_equal(box)) + def test_equal(self): + box = Solid.make_box(1, 1, 1) + self.assertEqual(box, box) + self.assertEqual(box, AlwaysEqual()) + + def test_not_equal(self): + box = Solid.make_box(1, 1, 1) + diff = Solid.make_box(1, 2, 3) + self.assertNotEqual(box, diff) + self.assertNotEqual(box, object()) + def test_tessellate(self): box123 = Solid.make_box(1, 2, 3) verts, triangles = box123.tessellate(1e-6) @@ -3439,6 +3464,20 @@ def test_compound(self): sl = ShapeList([Box(1, 2, 3), Vertex(1, 1, 1)]) self.assertAlmostEqual(sl.compound().volume, 1 * 2 * 3, 5) + def test_equal(self): + box = Box(1, 1, 1) + cyl = Cylinder(1, 1) + sl = ShapeList([box, cyl]) + same = ShapeList([cyl, box]) + self.assertEqual(sl, same) + self.assertEqual(sl, AlwaysEqual()) + + def test_not_equal(self): + sl = ShapeList([Box(1, 1, 1), Cylinder(1, 1)]) + diff = ShapeList([Box(1, 1, 1), Box(1, 2, 3)]) + self.assertNotEqual(sl, diff) + self.assertNotEqual(sl, object()) + class TestShells(DirectApiTestCase): def test_shell_init(self): @@ -3753,6 +3792,13 @@ def test_vector_equals(self): c = Vector(1, 2, 3.000001) self.assertEqual(a, b) self.assertEqual(a, c) + self.assertEqual(a, AlwaysEqual()) + + def test_vector_not_equal(self): + a = Vector(1, 2, 3) + b = Vector(3, 2, 1) + self.assertNotEqual(a, b) + self.assertNotEqual(a, object()) def test_vector_distance(self): """