From febc1a67bfba9e8d1df3cd9a152ea7575f06f409 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Mon, 21 Oct 2024 11:10:05 +0200 Subject: [PATCH] Refactor + Add assertions + Tests for offset related functions in `subsets.py`. Currently it just fails with invalid index error for certain cases. Also, the methods are not documented, which is bad because their behaviour is not very consistent. --- dace/subsets.py | 135 +++++++++++++++++++++++++----------------- tests/subsets_test.py | 128 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 54 deletions(-) create mode 100644 tests/subsets_test.py diff --git a/dace/subsets.py b/dace/subsets.py index e7b6869678..831666cc28 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1,12 +1,15 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace.serialize -from dace import data, symbolic, dtypes + import re -import sympy as sp +import warnings from functools import reduce +from typing import List, Optional, Sequence, Set, Union, Collection + +import sympy as sp import sympy.core.sympify -from typing import List, Optional, Sequence, Set, Union -import warnings + +import dace.serialize +from dace import symbolic from dace.config import Config @@ -20,6 +23,7 @@ def nng(expr): except AttributeError: # No free_symbols in expr return expr + def bounding_box_cover_exact(subset_a, subset_b) -> bool: min_elements_a = subset_a.min_element() max_elements_a = subset_a.max_element() @@ -29,8 +33,8 @@ def bounding_box_cover_exact(subset_a, subset_b) -> bool: # Covering only make sense if the two subsets have the same number of dimensions. if len(min_elements_a) != len(min_elements_b): return ValueError( - f"A bounding box of dimensionality {len(min_elements_a)} cannot" - f" test covering a bounding box of dimensionality {len(min_elements_b)}." + f"A bounding box of dimensionality {len(min_elements_a)} cannot" + f" test covering a bounding box of dimensionality {len(min_elements_b)}." ) return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True @@ -38,7 +42,8 @@ def bounding_box_cover_exact(subset_a, subset_b) -> bool: for rb, re, orb, ore in zip(min_elements_a, max_elements_a, min_elements_b, max_elements_b)]) -def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool: + +def bounding_box_symbolic_positive(subset_a, subset_b, approximation=False) -> bool: min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element() max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element() min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element() @@ -47,8 +52,8 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> # Covering only make sense if the two subsets have the same number of dimensions. if len(min_elements_a) != len(min_elements_b): return ValueError( - f"A bounding box of dimensionality {len(min_elements_a)} cannot" - f" test covering a bounding box of dimensionality {len(min_elements_b)}." + f"A bounding box of dimensionality {len(min_elements_a)} cannot" + f" test covering a bounding box of dimensionality {len(min_elements_b)}." ) for rb, re, orb, ore in zip(min_elements_a, max_elements_a, @@ -70,6 +75,7 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> return False return True + class Subset(object): """ Defines a subset of a data descriptor. """ @@ -80,7 +86,7 @@ def covers(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" ) if not Config.get('optimizer', 'symbolic_positive'): @@ -99,20 +105,22 @@ def covers(self, other): return False return True - + def covers_precise(self, other): """ Returns True if self contains all the elements in other. """ # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" ) # If self does not cover other with a bounding box union, return false. symbolic_positive = Config.get('optimizer', 'symbolic_positive') try: - bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other) + bounding_box_cover = bounding_box_cover_exact(self, + other) if symbolic_positive else bounding_box_symbolic_positive( + self, other) if not bounding_box_cover: return False except TypeError: @@ -151,21 +159,42 @@ def covers_precise(self, other): except: return False return True - # unknown type + # unknown type else: raise TypeError except TypeError: return False - def __repr__(self): return '%s (%s)' % (type(self).__name__, self.__str__()) - def offset(self, other, negative, indices=None): + def offset(self, other, negative: bool, indices=None): + """ + Updates `self` with the minimum and maximum elements (positively or negatively) offset by the minimum elements + of the `other` at indices specified by `indices`, with any other members remaining the same as `self`. + + Available for only _some_ derived classes of `Subset` such as `Range` or `Indices`. The behaviour is also + different depending on the class: + 1. `Range`: If `indices` has a smaller size than `self`, then updates only at those specified indices, and + `self` retains its size. + 2. `Indices`: `indices` must be `None`. `other` must not have a larger dimensions than `self`. Updates only at + the first `len(others)` entries. + """ raise NotImplementedError - def offset_new(self, other, negative, indices=None): + def offset_new(self, other, negative: bool, indices): + """ + Returns a _new_ object of the type `Self`, with the minimum and maximum elements (positively or negatively) + offset by the minimum elements of the `other` at indices specified by `indices`, with any other members + remaining the same as `self`. + + Available for only _some_ derived classes of `Subset` such as `Range` or `Indices`. The behaviour is also + different depending on the class: + 1. `Range`: If `indices` has a smaller size than `self`, then returns a subset with smaller dimensions. + 2. `Indices`: `indices` must be `None`. `other` must not have a larger dimensions than `self`. If `other` has a + smaller number of dimensions, returns a `Indices` with only that many dimensions. + """ raise NotImplementedError def at(self, i, strides): @@ -229,6 +258,7 @@ def _tuple_to_symexpr(val): @dace.serialize.serializable class Range(Subset): """ Subset defined in terms of a fixed range. """ + def __init__(self, ranges): parsed_ranges = [] parsed_tiles = [] @@ -402,19 +432,12 @@ def data_dims(self): return (sum(1 if (re - rb + 1) != 1 else 0 for rb, re, _ in self.ranges) + sum(1 if ts != 1 else 0 for ts in self.tile_sizes)) - def offset(self, other, negative, indices=None): - if not isinstance(other, Subset): - if isinstance(other, (list, tuple)): - other = Indices(other) - else: - other = Indices([other for _ in self.ranges]) + def offset_by(self, off: Collection, negative: bool, indices: Collection): + assert all(i < len(self.ranges) for i in indices) + assert all(i < len(off) for i in indices) mult = -1 if negative else 1 - if indices is None: - indices = set(range(len(self.ranges))) - off = other.min_element() - for i in indices: - rb, re, rs = self.ranges[i] - self.ranges[i] = (rb + mult * off[i], re + mult * off[i], rs) + return Range([(self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2]) + for i in indices]) def offset_new(self, other, negative, indices=None): if not isinstance(other, Subset): @@ -422,12 +445,16 @@ def offset_new(self, other, negative, indices=None): other = Indices(other) else: other = Indices([other for _ in self.ranges]) - mult = -1 if negative else 1 if indices is None: indices = set(range(len(self.ranges))) - off = other.min_element() - return Range([(self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2]) - for i in indices]) + return self.offset_by(other.min_element(), negative, indices) + + def offset(self, other, negative, indices=None): + if indices is None: + indices = set(range(len(self.ranges))) + new_ranges = self.offset_new(other, negative, indices).ranges + for i, r in zip(indices, new_ranges): + self.ranges[i] = r def dims(self): return len(self.ranges) @@ -578,7 +605,7 @@ def from_string(string): value = symbolic.pystr_to_symbolic(uni_dim_tokens[0].strip()) ranges.append((value, value, 1)) continue - #return Range(ranges) + # return Range(ranges) # If dimension has more than 4 tokens, the range is invalid if len(uni_dim_tokens) > 4: raise SyntaxError("Invalid range: {}".format(multi_dim_tokens)) @@ -848,6 +875,7 @@ def intersects(self, other: 'Range'): class Indices(Subset): """ A subset of one element representing a single index in an N-dimensional data descriptor. """ + def __init__(self, indices): if indices is None or len(indices) == 0: raise TypeError('Expected an array of index expressions: got empty' ' array or None') @@ -874,7 +902,7 @@ def from_json(obj, context=None): raise TypeError("from_json of class \"Indices\" called on json " "with type %s (expected 'Indices')" % obj['type']) - #return Indices(symbolic.SymExpr(obj['indices'])) + # return Indices(symbolic.SymExpr(obj['indices'])) return Indices([*map(symbolic.pystr_to_symbolic, obj['indices'])]) def __hash__(self): @@ -919,24 +947,24 @@ def strides(self): def absolute_strides(self, global_shape): return [1] * len(self.indices) - def offset(self, other, negative, indices=None): - if not isinstance(other, Subset): - if isinstance(other, (list, tuple)): - other = Indices(other) - else: - other = Indices([other for _ in self.indices]) + def offset_by(self, off: Collection, negative: bool): + assert len(off) <= len(self.indices) mult = -1 if negative else 1 - for i, off in enumerate(other.min_element()): - self.indices[i] += mult * off + return Indices([self.indices[i] + mult * off for i, off in enumerate(off)]) def offset_new(self, other, negative, indices=None): + assert indices is None if not isinstance(other, Subset): if isinstance(other, (list, tuple)): other = Indices(other) else: other = Indices([other for _ in self.indices]) - mult = -1 if negative else 1 - return Indices([self.indices[i] + mult * off for i, off in enumerate(other.min_element())]) + return self.offset_by(other.min_element(), negative) + + def offset(self, other, negative, indices=None): + new_indices = self.offset_new(other, negative, indices) + for i, new_ind in enumerate(new_indices): + self.indices[i] = new_ind def coord_at(self, i): """ Returns the offseted coordinates of this subset at @@ -1081,6 +1109,7 @@ def intersection(self, other: 'Indices'): return self return None + class SubsetUnion(Subset): """ Wrapper subset type that stores multiple Subsets in a list. @@ -1118,7 +1147,7 @@ def covers(self, other): return False else: return any(s.covers(other) for s in self.subset_list) - + def covers_precise(self, other): """ Returns True if this SubsetUnion covers another @@ -1144,7 +1173,7 @@ def __str__(self): string += " " string += subset.__str__() return string - + def dims(self): if not self.subset_list: return 0 @@ -1168,7 +1197,7 @@ def free_symbols(self) -> Set[str]: for subset in self.subset_list: result |= subset.free_symbols return result - + def replace(self, repl_dict): for subset in self.subset_list: subset.replace(repl_dict) @@ -1178,13 +1207,12 @@ def num_elements(self): min = 0 for subset in self.subset_list: try: - if subset.num_elements() < min or min ==0: + if subset.num_elements() < min or min == 0: min = subset.num_elements() except: continue - - return min + return min def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType, @@ -1251,8 +1279,6 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: return Range(result) - - def union(subset_a: Subset, subset_b: Subset) -> Subset: """ Compute the union of two Subset objects. If the subsets are not of the same type, degenerates to bounding-box @@ -1321,6 +1347,7 @@ def list_union(subset_a: Subset, subset_b: Subset) -> Subset: except TypeError: return None + def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]: """ Returns True if two subsets intersect, False if they do not, or diff --git a/tests/subsets_test.py b/tests/subsets_test.py new file mode 100644 index 0000000000..d86910b2f4 --- /dev/null +++ b/tests/subsets_test.py @@ -0,0 +1,128 @@ +import unittest +from typing import Collection + +import dace +from dace import subsets + + +def make_a_range_with_min_elements(min_elems: Collection): + return subsets.Range([(e, e, 1) for e in min_elems]) + + +class TestOffsetNew(unittest.TestCase): + def test_range_offset_same_shape(self): + n, m = dace.symbol('n', dtype=dace.int32, positive=True), dace.symbol('m', dtype=dace.int32, positive=True) + r0 = subsets.Range([(5, 5 + n - 1, 1), (5, 5 + m - 1, 1)]) + + # No offset + off = [0, 0] + rExpect = r0 + self.assertEqual(rExpect, r0.offset_by(off, False, [0, 1])) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, [0, 1])) + self.assertEqual(rExpect, r0.offset_by(off, True, [0, 1])) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), True, [0, 1])) + + # Positive offset + off = [5, 4] + negative = False + rExpect = subsets.Range([(10, 10 + n - 1, 1), (9, 9 + m - 1, 1)]) + self.assertEqual(rExpect, r0.offset_by(off, negative, [0, 1])) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, [0, 1])) + # Only partially + rExpect = subsets.Range([(9, 9 + m - 1, 1)]) + partInds = [1] + self.assertEqual(rExpect, r0.offset_by(off, negative, partInds)) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, partInds)) + + # Negative offset + off = [5, 4] + negative = True + rExpect = subsets.Range([(0, n - 1, 1), (1, 1 + m - 1, 1)]) + self.assertEqual(rExpect, r0.offset_by(off, negative, [0, 1])) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, [0, 1])) + + def test_range_offset_partial_indices(self): + n, m = dace.symbol('n', dtype=dace.int32, positive=True), dace.symbol('m', dtype=dace.int32, positive=True) + r0 = subsets.Range([(5, 5 + n - 1, 1), (5, 5 + m - 1, 1)]) + off = [5, 4] + + partInds = [0] + rExpect = subsets.Range([(10, 10 + n - 1, 1)]) + self.assertEqual(rExpect, r0.offset_by(off, False, partInds)) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, partInds)) + + partInds = [1] + rExpect = subsets.Range([(9, 9 + m - 1, 1)]) + self.assertEqual(rExpect, r0.offset_by(off, False, partInds)) + self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, partInds)) + + def test_range_offset_bad_input(self): + n, m = dace.symbol('n', dtype=dace.int32, positive=True), dace.symbol('m', dtype=dace.int32, positive=True) + r0 = subsets.Range([(5, 5 + n - 1, 1), (5, 5 + m - 1, 1)]) + + # Offset list is too short + off = [5] + inds = [1] + with self.assertRaises(AssertionError): + r0.offset_by(off, False, inds) + with self.assertRaises(AssertionError): + r0.offset_new(make_a_range_with_min_elements(off), False, inds) + + # Index out of bounds. + off = [5, 4] + inds = [0, 1001] + with self.assertRaises(AssertionError): + r0.offset_by(off, False, inds) + with self.assertRaises(AssertionError): + r0.offset_new(make_a_range_with_min_elements(off), False, inds) + + def test_indices_offset_same_shape(self): + n, m = dace.symbol('n', dtype=dace.int32, positive=True), dace.symbol('m', dtype=dace.int32, positive=True) + ind0 = subsets.Indices([n, m]) + + # No offset + off = [0, 0] + indExpect = ind0 + self.assertEqual(indExpect, ind0.offset_by(off, False)) + self.assertEqual(indExpect, ind0.offset_new(make_a_range_with_min_elements(off), False)) + self.assertEqual(indExpect, ind0.offset_by(off, True)) + self.assertEqual(indExpect, ind0.offset_new(make_a_range_with_min_elements(off), True)) + + # Positive offset + off = [5, 4] + negative = False + indExpect = subsets.Indices([n + 5, m + 4]) + self.assertEqual(indExpect, ind0.offset_by(off, negative)) + self.assertEqual(indExpect, ind0.offset_new(make_a_range_with_min_elements(off), negative)) + + # Negative offset + off = [5, 4] + negative = True + indExpect = subsets.Indices([n - 5, m - 4]) + self.assertEqual(indExpect, ind0.offset_by(off, negative)) + self.assertEqual(indExpect, ind0.offset_new(make_a_range_with_min_elements(off), negative)) + + def test_indices_offset_smaller_dims(self): + n, m = dace.symbol('n', dtype=dace.int32, positive=True), dace.symbol('m', dtype=dace.int32, positive=True) + ind0 = subsets.Indices([n, m]) + + # Offset size too small + off = [5] + indExpect = subsets.Indices([n + 5]) + self.assertEqual(indExpect, ind0.offset_by(off, False)) + self.assertEqual(indExpect, ind0.offset_new(make_a_range_with_min_elements(off), False)) + + def test_indices_offset_bad_input(self): + n, m = dace.symbol('n', dtype=dace.int32, positive=True), dace.symbol('m', dtype=dace.int32, positive=True) + ind0 = subsets.Indices([n, m]) + + # Offset size too big + off = [5, 4, 3] + with self.assertRaises(AssertionError): + ind0.offset_by(off, False) + with self.assertRaises(AssertionError): + ind0.offset_new(make_a_range_with_min_elements(off), False) + + +if __name__ == '__main__': + unittest.main()