Skip to content

Commit

Permalink
Support equality on multidimensional arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
BCSharp committed Nov 30, 2024
1 parent 2bb8d24 commit 0df64c0
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
34 changes: 31 additions & 3 deletions Src/IronPython/Runtime/Operations/ArrayOps.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,38 @@ public static object __eq__(CodeContext context, Array self, [NotNone] Array oth
if (other is null) throw PythonOps.TypeError("expected Array, got None");

if (self.GetType() != other.GetType()) return ScriptingRuntimeHelpers.False;
// same type implies: same rank, same element type
for (int d = 0; d < self.Rank; d++) {
if (self.GetLowerBound(d) != other.GetLowerBound(d)) return ScriptingRuntimeHelpers.False;
if (self.GetUpperBound(d) != other.GetUpperBound(d)) return ScriptingRuntimeHelpers.False;
}
if (self.Length == 0) return ScriptingRuntimeHelpers.True; // fast track

return ScriptingRuntimeHelpers.BooleanToObject(
((IStructuralEquatable)self).Equals(other, context.LanguageContext.EqualityComparerNonGeneric)
);
if (self.Rank == 1 && self.GetLowerBound(0) == 0 ) {
// IStructuralEquatable.Equals only works for 1-dim, 0-based arrays
return ScriptingRuntimeHelpers.BooleanToObject(
((IStructuralEquatable)self).Equals(other, context.LanguageContext.EqualityComparerNonGeneric)
);
} else {
int[] ix = new int[self.Rank];
for (int d = 0; d < self.Rank; d++) {
ix[d] = self.GetLowerBound(d);
}
for (int i = 0; i < self.Length; i++) {
if (!PythonOps.EqualRetBool(self.GetValue(ix), other.GetValue(ix))) {
return ScriptingRuntimeHelpers.False;
}
for (int d = self.Rank - 1; d >= 0; d--) {
if (ix[d] < self.GetUpperBound(d)) {
ix[d]++;
break;
} else {
ix[d] = self.GetLowerBound(d);
}
}
}
return ScriptingRuntimeHelpers.True;
}
}

[StaticExtensionMethod]
Expand Down
47 changes: 47 additions & 0 deletions Tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,4 +318,51 @@ def test_equality(self):
self.assertTrue(a != l)
self.assertTrue(l != a)

def test_equality_base(self):
a = System.Array.CreateInstance(int, (5,), (5,))
a2 = System.Array.CreateInstance(int, (5,), (5,))
b = System.Array.CreateInstance(int, (6,), (5,))
c = System.Array.CreateInstance(int, (5,), (6,))
d = System.Array.CreateInstance(int, (6,), (6,))

self.assertTrue(a == a2)
self.assertFalse(a == b)
self.assertFalse(a == c)
self.assertFalse(a == d)

self.assertFalse(a != a2)
self.assertTrue(a != b)
self.assertTrue(a != c)
self.assertTrue(a != d)

def test_equality_rank(self):
a = System.Array.CreateInstance(int, 5, 6)
a2 = System.Array.CreateInstance(int, 5, 6)
b = System.Array.CreateInstance(int, 5, 6)
b[0, 0] = 1
c = System.Array.CreateInstance(int, (6, 5), (0, 0))
c[0, 0] = 1
d = System.Array.CreateInstance(int, (6, 5), (1, 1))
d[1, 1] = 1
d1 = System.Array.CreateInstance(int, (6, 5), (1, 1))
d1[1, 1] = 1

self.assertTrue(a == a2)
self.assertFalse(a == b) # different element
self.assertFalse(a == c) # different rank
self.assertFalse(a == d) # different rank
self.assertFalse(b == c) # different shape
self.assertFalse(b == d) # different shape & base
self.assertFalse(c == d) # different base
self.assertTrue(d == d1)

self.assertFalse(a != a2)
self.assertTrue(a != b)
self.assertTrue(a != c)
self.assertTrue(a != d)
self.assertTrue(b != c)
self.assertTrue(b != d)
self.assertTrue(c != d)
self.assertFalse(d != d1)

run_test(__name__)

0 comments on commit 0df64c0

Please sign in to comment.