Skip to content

Commit

Permalink
Since indices may be given as a set for Ranges, and the iteration
Browse files Browse the repository at this point in the history
order of a set is not specified, we should be able to handle that too
(e.g., by returning from `offset_by()` a map from indices to new values,
instead of just a sequence).
  • Loading branch information
pratyai committed Oct 22, 2024
1 parent 32016da commit cffcdd2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
15 changes: 9 additions & 6 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,14 @@ def data_dims(self):
return (sum(1 if (re - rb + 1) != 1 else 0 for rb, re, _ in self.ranges) + sum(1 if ts != 1 else 0
for ts in self.tile_sizes))

def offset_by(self, off: Collection, negative: bool, indices: Collection):
def offset_by(self, off: Sequence, negative: bool, indices: Collection):
assert all(i < len(self.ranges) for i in indices)
assert all(i < len(off) for i in indices)
mult = -1 if negative else 1
return Range([(self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2])
for i in indices])
return {
i: (self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2])
for i in indices
}

def offset_new(self, other, negative, indices=None):
if not isinstance(other, Subset):
Expand All @@ -447,13 +449,14 @@ def offset_new(self, other, negative, indices=None):
other = Indices([other for _ in self.ranges])
if indices is None:
indices = set(range(len(self.ranges)))
return self.offset_by(other.min_element(), negative, indices)
new_ranges = self.offset_by(other.min_element(), negative, indices)
return Range([new_ranges[i] for i in sorted(indices)])

def offset(self, other, negative, indices=None):
if indices is None:
indices = set(range(len(self.ranges)))
new_ranges = self.offset_new(other, negative, indices).ranges
for i, r in zip(indices, new_ranges):
for i, r in zip(sorted(indices), new_ranges):
self.ranges[i] = r

def dims(self):
Expand Down Expand Up @@ -947,7 +950,7 @@ def strides(self):
def absolute_strides(self, global_shape):
return [1] * len(self.indices)

def offset_by(self, off: Collection, negative: bool):
def offset_by(self, off: Sequence, negative: bool):
assert len(off) <= len(self.indices)
mult = -1 if negative else 1
return Indices([self.indices[i] + mult * off for i, off in enumerate(off)])
Expand Down
14 changes: 7 additions & 7 deletions tests/subsets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@ def test_range_offset_same_shape(self):
# No offset
off = [0, 0]
rExpect = r0
self.assertEqual(rExpect, r0.offset_by(off, False, [0, 1]))
self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, False, [0, 1]))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, [0, 1]))
self.assertEqual(rExpect, r0.offset_by(off, True, [0, 1]))
self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, True, [0, 1]))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), True, [0, 1]))

# Positive offset
off = [5, 4]
negative = False
rExpect = subsets.Range([(10, 10 + n - 1, 1), (9, 9 + m - 1, 1)])
self.assertEqual(rExpect, r0.offset_by(off, negative, [0, 1]))
self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, negative, [0, 1]))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, [0, 1]))
# Only partially
rExpect = subsets.Range([(9, 9 + m - 1, 1)])
partInds = [1]
self.assertEqual(rExpect, r0.offset_by(off, negative, partInds))
self.assertEqual({1: rExpect.ranges[0]}, r0.offset_by(off, negative, partInds))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, partInds))

# Negative offset
off = [5, 4]
negative = True
rExpect = subsets.Range([(0, n - 1, 1), (1, 1 + m - 1, 1)])
self.assertEqual(rExpect, r0.offset_by(off, negative, [0, 1]))
self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, negative, [0, 1]))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, [0, 1]))

def test_range_offset_partial_indices(self):
Expand All @@ -48,12 +48,12 @@ def test_range_offset_partial_indices(self):

partInds = [0]
rExpect = subsets.Range([(10, 10 + n - 1, 1)])
self.assertEqual(rExpect, r0.offset_by(off, False, partInds))
self.assertEqual({0: rExpect.ranges[0]}, r0.offset_by(off, False, partInds))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, partInds))

partInds = [1]
rExpect = subsets.Range([(9, 9 + m - 1, 1)])
self.assertEqual(rExpect, r0.offset_by(off, False, partInds))
self.assertEqual({1: rExpect.ranges[0]}, r0.offset_by(off, False, partInds))
self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, partInds))

def test_range_offset_bad_input(self):
Expand Down

0 comments on commit cffcdd2

Please sign in to comment.