Skip to content

Commit

Permalink
Refactor + Add assertions + Tests for offset related functions in
Browse files Browse the repository at this point in the history
`subsets.py`.

Currently it just fails with invalid index error for certain cases.
Also, the methods are not documented, which is bad because their
behaviour is not very consistent.
  • Loading branch information
pratyai committed Oct 21, 2024
1 parent 975a065 commit febc1a6
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 54 deletions.
135 changes: 81 additions & 54 deletions dace/subsets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace.serialize
from dace import data, symbolic, dtypes

import re
import sympy as sp
import warnings
from functools import reduce
from typing import List, Optional, Sequence, Set, Union, Collection

import sympy as sp
import sympy.core.sympify
from typing import List, Optional, Sequence, Set, Union
import warnings

import dace.serialize
from dace import symbolic
from dace.config import Config


Expand All @@ -20,6 +23,7 @@ def nng(expr):
except AttributeError: # No free_symbols in expr
return expr


def bounding_box_cover_exact(subset_a, subset_b) -> bool:
min_elements_a = subset_a.min_element()
max_elements_a = subset_a.max_element()
Expand All @@ -29,16 +33,17 @@ def bounding_box_cover_exact(subset_a, subset_b) -> bool:
# Covering only make sense if the two subsets have the same number of dimensions.
if len(min_elements_a) != len(min_elements_b):
return ValueError(
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
)

return all([(symbolic.simplify_ext(nng(rb)) <= symbolic.simplify_ext(nng(orb))) == True
and (symbolic.simplify_ext(nng(re)) >= symbolic.simplify_ext(nng(ore))) == True
for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
min_elements_b, max_elements_b)])

def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)-> bool:

def bounding_box_symbolic_positive(subset_a, subset_b, approximation=False) -> bool:
min_elements_a = subset_a.min_element_approx() if approximation else subset_a.min_element()
max_elements_a = subset_a.max_element_approx() if approximation else subset_a.max_element()
min_elements_b = subset_b.min_element_approx() if approximation else subset_b.min_element()
Expand All @@ -47,8 +52,8 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)->
# Covering only make sense if the two subsets have the same number of dimensions.
if len(min_elements_a) != len(min_elements_b):
return ValueError(
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
f"A bounding box of dimensionality {len(min_elements_a)} cannot"
f" test covering a bounding box of dimensionality {len(min_elements_b)}."
)

for rb, re, orb, ore in zip(min_elements_a, max_elements_a,
Expand All @@ -70,6 +75,7 @@ def bounding_box_symbolic_positive(subset_a, subset_b, approximation = False)->
return False
return True


class Subset(object):
""" Defines a subset of a data descriptor. """

Expand All @@ -80,7 +86,7 @@ def covers(self, other):
# Subsets of different dimensionality can never cover each other.
if self.dims() != other.dims():
return ValueError(
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
)

if not Config.get('optimizer', 'symbolic_positive'):
Expand All @@ -99,20 +105,22 @@ def covers(self, other):
return False

return True

def covers_precise(self, other):
""" Returns True if self contains all the elements in other. """

# Subsets of different dimensionality can never cover each other.
if self.dims() != other.dims():
return ValueError(
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}"
)

# If self does not cover other with a bounding box union, return false.
symbolic_positive = Config.get('optimizer', 'symbolic_positive')
try:
bounding_box_cover = bounding_box_cover_exact(self, other) if symbolic_positive else bounding_box_symbolic_positive(self, other)
bounding_box_cover = bounding_box_cover_exact(self,
other) if symbolic_positive else bounding_box_symbolic_positive(
self, other)
if not bounding_box_cover:
return False
except TypeError:
Expand Down Expand Up @@ -151,21 +159,42 @@ def covers_precise(self, other):
except:
return False
return True
# unknown type
# unknown type
else:
raise TypeError

except TypeError:
return False


def __repr__(self):
return '%s (%s)' % (type(self).__name__, self.__str__())

def offset(self, other, negative, indices=None):
def offset(self, other, negative: bool, indices=None):
"""
Updates `self` with the minimum and maximum elements (positively or negatively) offset by the minimum elements
of the `other` at indices specified by `indices`, with any other members remaining the same as `self`.
Available for only _some_ derived classes of `Subset` such as `Range` or `Indices`. The behaviour is also
different depending on the class:
1. `Range`: If `indices` has a smaller size than `self`, then updates only at those specified indices, and
`self` retains its size.
2. `Indices`: `indices` must be `None`. `other` must not have a larger dimensions than `self`. Updates only at
the first `len(others)` entries.
"""
raise NotImplementedError

def offset_new(self, other, negative, indices=None):
def offset_new(self, other, negative: bool, indices):
"""
Returns a _new_ object of the type `Self`, with the minimum and maximum elements (positively or negatively)
offset by the minimum elements of the `other` at indices specified by `indices`, with any other members
remaining the same as `self`.
Available for only _some_ derived classes of `Subset` such as `Range` or `Indices`. The behaviour is also
different depending on the class:
1. `Range`: If `indices` has a smaller size than `self`, then returns a subset with smaller dimensions.
2. `Indices`: `indices` must be `None`. `other` must not have a larger dimensions than `self`. If `other` has a
smaller number of dimensions, returns a `Indices` with only that many dimensions.
"""
raise NotImplementedError

def at(self, i, strides):
Expand Down Expand Up @@ -229,6 +258,7 @@ def _tuple_to_symexpr(val):
@dace.serialize.serializable
class Range(Subset):
""" Subset defined in terms of a fixed range. """

def __init__(self, ranges):
parsed_ranges = []
parsed_tiles = []
Expand Down Expand Up @@ -402,32 +432,29 @@ def data_dims(self):
return (sum(1 if (re - rb + 1) != 1 else 0 for rb, re, _ in self.ranges) + sum(1 if ts != 1 else 0
for ts in self.tile_sizes))

def offset(self, other, negative, indices=None):
if not isinstance(other, Subset):
if isinstance(other, (list, tuple)):
other = Indices(other)
else:
other = Indices([other for _ in self.ranges])
def offset_by(self, off: Collection, negative: bool, indices: Collection):
assert all(i < len(self.ranges) for i in indices)
assert all(i < len(off) for i in indices)
mult = -1 if negative else 1
if indices is None:
indices = set(range(len(self.ranges)))
off = other.min_element()
for i in indices:
rb, re, rs = self.ranges[i]
self.ranges[i] = (rb + mult * off[i], re + mult * off[i], rs)
return Range([(self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2])
for i in indices])

def offset_new(self, other, negative, indices=None):
if not isinstance(other, Subset):
if isinstance(other, (list, tuple)):
other = Indices(other)
else:
other = Indices([other for _ in self.ranges])
mult = -1 if negative else 1
if indices is None:
indices = set(range(len(self.ranges)))
off = other.min_element()
return Range([(self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2])
for i in indices])
return self.offset_by(other.min_element(), negative, indices)

def offset(self, other, negative, indices=None):
if indices is None:
indices = set(range(len(self.ranges)))
new_ranges = self.offset_new(other, negative, indices).ranges
for i, r in zip(indices, new_ranges):
self.ranges[i] = r

def dims(self):
return len(self.ranges)
Expand Down Expand Up @@ -578,7 +605,7 @@ def from_string(string):
value = symbolic.pystr_to_symbolic(uni_dim_tokens[0].strip())
ranges.append((value, value, 1))
continue
#return Range(ranges)
# return Range(ranges)
# If dimension has more than 4 tokens, the range is invalid
if len(uni_dim_tokens) > 4:
raise SyntaxError("Invalid range: {}".format(multi_dim_tokens))
Expand Down Expand Up @@ -848,6 +875,7 @@ def intersects(self, other: 'Range'):
class Indices(Subset):
""" A subset of one element representing a single index in an
N-dimensional data descriptor. """

def __init__(self, indices):
if indices is None or len(indices) == 0:
raise TypeError('Expected an array of index expressions: got empty' ' array or None')
Expand All @@ -874,7 +902,7 @@ def from_json(obj, context=None):
raise TypeError("from_json of class \"Indices\" called on json "
"with type %s (expected 'Indices')" % obj['type'])

#return Indices(symbolic.SymExpr(obj['indices']))
# return Indices(symbolic.SymExpr(obj['indices']))
return Indices([*map(symbolic.pystr_to_symbolic, obj['indices'])])

def __hash__(self):
Expand Down Expand Up @@ -919,24 +947,24 @@ def strides(self):
def absolute_strides(self, global_shape):
return [1] * len(self.indices)

def offset(self, other, negative, indices=None):
if not isinstance(other, Subset):
if isinstance(other, (list, tuple)):
other = Indices(other)
else:
other = Indices([other for _ in self.indices])
def offset_by(self, off: Collection, negative: bool):
assert len(off) <= len(self.indices)
mult = -1 if negative else 1
for i, off in enumerate(other.min_element()):
self.indices[i] += mult * off
return Indices([self.indices[i] + mult * off for i, off in enumerate(off)])

def offset_new(self, other, negative, indices=None):
assert indices is None
if not isinstance(other, Subset):
if isinstance(other, (list, tuple)):
other = Indices(other)
else:
other = Indices([other for _ in self.indices])
mult = -1 if negative else 1
return Indices([self.indices[i] + mult * off for i, off in enumerate(other.min_element())])
return self.offset_by(other.min_element(), negative)

def offset(self, other, negative, indices=None):
new_indices = self.offset_new(other, negative, indices)
for i, new_ind in enumerate(new_indices):
self.indices[i] = new_ind

def coord_at(self, i):
""" Returns the offseted coordinates of this subset at
Expand Down Expand Up @@ -1081,6 +1109,7 @@ def intersection(self, other: 'Indices'):
return self
return None


class SubsetUnion(Subset):
"""
Wrapper subset type that stores multiple Subsets in a list.
Expand Down Expand Up @@ -1118,7 +1147,7 @@ def covers(self, other):
return False
else:
return any(s.covers(other) for s in self.subset_list)

def covers_precise(self, other):
"""
Returns True if this SubsetUnion covers another
Expand All @@ -1144,7 +1173,7 @@ def __str__(self):
string += " "
string += subset.__str__()
return string

def dims(self):
if not self.subset_list:
return 0
Expand All @@ -1168,7 +1197,7 @@ def free_symbols(self) -> Set[str]:
for subset in self.subset_list:
result |= subset.free_symbols
return result

def replace(self, repl_dict):
for subset in self.subset_list:
subset.replace(repl_dict)
Expand All @@ -1178,13 +1207,12 @@ def num_elements(self):
min = 0
for subset in self.subset_list:
try:
if subset.num_elements() < min or min ==0:
if subset.num_elements() < min or min == 0:
min = subset.num_elements()
except:
continue

return min

return min


def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType,
Expand Down Expand Up @@ -1251,8 +1279,6 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
return Range(result)




def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
Expand Down Expand Up @@ -1321,6 +1347,7 @@ def list_union(subset_a: Subset, subset_b: Subset) -> Subset:
except TypeError:
return None


def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
"""
Returns True if two subsets intersect, False if they do not, or
Expand Down
Loading

0 comments on commit febc1a6

Please sign in to comment.