diff --git a/dace/frontend/python/memlet_parser.py b/dace/frontend/python/memlet_parser.py index 9bd051be5c..a95bf82046 100644 --- a/dace/frontend/python/memlet_parser.py +++ b/dace/frontend/python/memlet_parser.py @@ -16,6 +16,22 @@ MemletType = Union[ast.Call, ast.Attribute, ast.Subscript, ast.Name] +if sys.version_info < (3, 8): + _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) + BytesConstant = ast.Bytes + EllipsisConstant = ast.Ellipsis + NameConstant = ast.NameConstant + NumConstant = ast.Num + StrConstant = ast.Str +else: + _simple_ast_nodes = (ast.Constant, ast.Name) + BytesConstant = ast.Constant + EllipsisConstant = ast.Constant + NameConstant = ast.Constant + NumConstant = ast.Constant + StrConstant = ast.Constant + + @dataclass class MemletExpr: name: str @@ -125,7 +141,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): ndslice[j] = (0, array.shape[j] - 1, 1) idx += 1 new_idx += 1 - elif (dim is None or (isinstance(dim, (ast.Constant, ast.NameConstant)) and dim.value is None)): + elif (dim is None or (isinstance(dim, (ast.Constant, NameConstant)) and dim.value is None)): new_axes.append(new_idx) new_idx += 1 # NOTE: Do not increment idx here