Skip to content

Commit

Permalink
Merge pull request #707 from alexer/fix-eq
Browse files Browse the repository at this point in the history
Fix __eq__ and __ne__ for classes implementing them
  • Loading branch information
gumyr authored Sep 22, 2024
2 parents ca0597d + acbebfb commit 720bee9
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 37 deletions.
51 changes: 21 additions & 30 deletions src/build123d/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,10 @@ def __repr__(self) -> str:

__str__ = __repr__

def __eq__(self, other: Vector) -> bool: # type: ignore[override]
def __eq__(self, other: object) -> bool:
"""Vectors equal operator =="""
if not isinstance(other, Vector):
return NotImplemented
return self.wrapped.IsEqual(other.wrapped, 0.00001, 0.00001)

def __hash__(self) -> int:
Expand Down Expand Up @@ -670,7 +672,7 @@ def __str__(self) -> str:

def __eq__(self, other: object) -> bool:
if not isinstance(other, Axis):
return False
return NotImplemented
return self.position == other.position and self.direction == other.direction

def located(self, new_location: Location):
Expand Down Expand Up @@ -1468,10 +1470,10 @@ def __mul__(self, other: T) -> T:
def __pow__(self, exponent: int) -> Location:
return Location(self.wrapped.Powered(exponent))

def __eq__(self, other: Location) -> bool:
def __eq__(self, other: object) -> bool:
"""Compare Locations"""
if not isinstance(other, Location):
raise ValueError("other must be a Location")
return NotImplemented
quaternion1 = gp_Quaternion()
quaternion1.SetEulerAngles(
gp_EulerSequence.gp_Intrinsic_XYZ,
Expand Down Expand Up @@ -2139,27 +2141,6 @@ def offset(self, amount: float) -> Plane:
origin=self.origin + self.z_dir * amount, x_dir=self.x_dir, z_dir=self.z_dir
)

def _eq_iter(self, other: Plane):
"""Iterator to successively test equality
Args:
other: Plane to compare to
Returns:
Are planes equal
"""
# equality tolerances
eq_tolerance_origin = 1e-6
eq_tolerance_dot = 1e-6

yield isinstance(other, Plane) # comparison is with another Plane
# origins are the same
yield abs(self._origin - other.origin) < eq_tolerance_origin
# z-axis vectors are parallel (assumption: both are unit vectors)
yield abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
# x-axis vectors are parallel (assumption: both are unit vectors)
yield abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot

def __copy__(self) -> Plane:
"""Return copy of self"""
return Plane(gp_Pln(self.wrapped.Position()))
Expand All @@ -2168,13 +2149,23 @@ def __deepcopy__(self, _memo) -> Plane:
"""Return deepcopy of self"""
return Plane(gp_Pln(self.wrapped.Position()))

def __eq__(self, other: Plane):
def __eq__(self, other: object):
"""Are planes equal operator =="""
return all(self._eq_iter(other))
if not isinstance(other, Plane):
return NotImplemented

def __ne__(self, other: Plane):
"""Are planes not equal operator !+"""
return not self.__eq__(other)
# equality tolerances
eq_tolerance_origin = 1e-6
eq_tolerance_dot = 1e-6

return (
# origins are the same
abs(self._origin - other.origin) < eq_tolerance_origin
# z-axis vectors are parallel (assumption: both are unit vectors)
and abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
# x-axis vectors are parallel (assumption: both are unit vectors)
and abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
)

def __neg__(self) -> Plane:
"""Reverse z direction of plane operator -"""
Expand Down
12 changes: 9 additions & 3 deletions src/build123d/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,7 +1996,7 @@ def is_equal(self, other: Shape) -> bool:

def __eq__(self, other) -> bool:
"""Are shapes same operator =="""
return self.is_same(other) if isinstance(other, Shape) else False
return self.is_same(other) if isinstance(other, Shape) else NotImplemented

def is_valid(self) -> bool:
"""Returns True if no defect is detected on the shape S or any of its
Expand Down Expand Up @@ -3728,9 +3728,15 @@ def __or__(self, filter_by: Union[Axis, GeomType] = Axis.Z):
"""Filter by axis or geomtype operator |"""
return self.filter_by(filter_by)

def __eq__(self, other: ShapeList):
def __eq__(self, other: object):
"""ShapeLists equality operator =="""
return set(self) == set(other)
return set(self) == set(other) if isinstance(other, ShapeList) else NotImplemented

# Normally implementing __eq__ is enough, but ShapeList subclasses list,
# which already implements __ne__, so we need to override it, too
def __ne__(self, other: ShapeList):
"""ShapeLists inequality operator !="""
return set(self) != set(other) if isinstance(other, ShapeList) else NotImplemented

def __add__(self, other: ShapeList):
"""Combine two ShapeLists together operator +"""
Expand Down
54 changes: 50 additions & 4 deletions tests/test_direct_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@
RAD2DEG = 180 / math.pi


# Always equal to any other object, to test that __eq__ cooperation is working
class AlwaysEqual:
def __eq__(self, other):
return True


class DirectApiTestCase(unittest.TestCase):
def assertTupleAlmostEquals(
self,
Expand Down Expand Up @@ -365,13 +371,13 @@ def test_axis_equal(self):
self.assertEqual(Axis.X, Axis.X)
self.assertEqual(Axis.Y, Axis.Y)
self.assertEqual(Axis.Z, Axis.Z)
self.assertEqual(Axis.X, AlwaysEqual())

def test_axis_not_equal(self):
self.assertNotEqual(Axis.X, Axis.Y)
random_obj = object()
self.assertNotEqual(Axis.X, random_obj)


class TestBoundBox(DirectApiTestCase):
def test_basic_bounding_box(self):
v = Vertex(1, 1, 1)
Expand Down Expand Up @@ -1758,15 +1764,21 @@ def test_to_axis(self):
self.assertVectorAlmostEquals(axis.position, (1, 2, 3), 6)
self.assertVectorAlmostEquals(axis.direction, (0, 1, 0), 6)

def test_eq(self):
def test_equal(self):
loc = Location((1, 2, 3), (4, 5, 6))
diff_position = Location((10, 20, 30), (4, 5, 6))
diff_orientation = Location((1, 2, 3), (40, 50, 60))
same = Location((1, 2, 3), (4, 5, 6))

self.assertEqual(loc, same)
self.assertEqual(loc, AlwaysEqual())

def test_not_equal(self):
loc = Location((1, 2, 3), (40, 50, 60))
diff_position = Location((3, 2, 1), (40, 50, 60))
diff_orientation = Location((1, 2, 3), (60, 50, 40))

self.assertNotEqual(loc, diff_position)
self.assertNotEqual(loc, diff_orientation)
self.assertNotEqual(loc, object())

def test_neg(self):
loc = Location((1, 2, 3), (0, 35, 127))
Expand Down Expand Up @@ -2694,6 +2706,8 @@ def test_plane_equal(self):
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
)
# __eq__ cooperation
self.assertEqual(Plane.XY, AlwaysEqual())

def test_plane_not_equal(self):
# type difference
Expand Down Expand Up @@ -2983,6 +2997,17 @@ def test_is_equal(self):
box = Solid.make_box(1, 1, 1)
self.assertTrue(box.is_equal(box))

def test_equal(self):
box = Solid.make_box(1, 1, 1)
self.assertEqual(box, box)
self.assertEqual(box, AlwaysEqual())

def test_not_equal(self):
box = Solid.make_box(1, 1, 1)
diff = Solid.make_box(1, 2, 3)
self.assertNotEqual(box, diff)
self.assertNotEqual(box, object())

def test_tessellate(self):
box123 = Solid.make_box(1, 2, 3)
verts, triangles = box123.tessellate(1e-6)
Expand Down Expand Up @@ -3492,6 +3517,20 @@ def test_compound(self):
sl = ShapeList([Box(1, 2, 3), Vertex(1, 1, 1)])
self.assertAlmostEqual(sl.compound().volume, 1 * 2 * 3, 5)

def test_equal(self):
box = Box(1, 1, 1)
cyl = Cylinder(1, 1)
sl = ShapeList([box, cyl])
same = ShapeList([cyl, box])
self.assertEqual(sl, same)
self.assertEqual(sl, AlwaysEqual())

def test_not_equal(self):
sl = ShapeList([Box(1, 1, 1), Cylinder(1, 1)])
diff = ShapeList([Box(1, 1, 1), Box(1, 2, 3)])
self.assertNotEqual(sl, diff)
self.assertNotEqual(sl, object())


class TestShells(DirectApiTestCase):
def test_shell_init(self):
Expand Down Expand Up @@ -3806,6 +3845,13 @@ def test_vector_equals(self):
c = Vector(1, 2, 3.000001)
self.assertEqual(a, b)
self.assertEqual(a, c)
self.assertEqual(a, AlwaysEqual())

def test_vector_not_equal(self):
a = Vector(1, 2, 3)
b = Vector(3, 2, 1)
self.assertNotEqual(a, b)
self.assertNotEqual(a, object())

def test_vector_distance(self):
"""
Expand Down

0 comments on commit 720bee9

Please sign in to comment.