Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add offset normalization to Fortran frontend #1367

Merged
merged 24 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2420440
Support in Fortran frontend arrays with offset declaration
mcopik Jul 24, 2023
63b074b
Support shape attribute specification in the Fortran frontend
mcopik Jul 24, 2023
e1b4399
Rename array attributes test
mcopik Jul 24, 2023
37fa580
Remove old code
mcopik Aug 14, 2023
b9e9f61
Fix handling of non-dimensional attributes in Fortran frontend
mcopik Aug 14, 2023
5f212e7
Merge branch 'master' into fortran_frontend_array_dimensions
acalotoiu Aug 21, 2023
427f467
Add Fortran AST transformation assigning to each node its parent scope
mcopik Sep 8, 2023
0d19df2
Add new Fortran parser function to export pure AST, not SDFG
mcopik Sep 8, 2023
ab7930d
Merge branch 'fortran_frontend_array_dimensions' into fortran_ast_par…
mcopik Sep 8, 2023
5cfbed3
Add Fortran AST pass to gather all variable declarations inside a scope
mcopik Sep 8, 2023
2296556
First implementation of the offset normalization pass
mcopik Sep 8, 2023
3f76982
Add Fortran AST transformation assigning to each node its parent scope
mcopik Sep 8, 2023
60e9547
Add new Fortran parser function to export pure AST, not SDFG
mcopik Sep 8, 2023
17eaf5a
Add Fortran AST pass to gather all variable declarations inside a scope
mcopik Sep 8, 2023
1be4754
First implementation of the offset normalization pass
mcopik Sep 8, 2023
52f9d51
Merge branch 'fortran_ast_parents' of github.com:spcl/dace into fortr…
mcopik Sep 8, 2023
027f1e2
Remove dead and old code
mcopik Sep 8, 2023
b6d9320
Update the 2D offset normalizer tests to verify offsets on the AST level
mcopik Sep 8, 2023
379dada
Fix handling of ArrayToLoop when involved arrays have offsets
mcopik Sep 8, 2023
c5ce575
Add test verifying a 1D ArrayToLoop transform with offsets
mcopik Sep 8, 2023
2436051
Add test verifying that Fortran offset normalizer works for 1D and 2D…
mcopik Sep 8, 2023
ec77693
Adjust offsets in Array2Loop only when it has offset different than d…
mcopik Sep 8, 2023
b37c1f5
Remove dead code
mcopik Sep 8, 2023
70c33dd
Add support for Fortran modules in scope parent assignment pass
mcopik Sep 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion dace/frontend/fortran/ast_internal_classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from typing import Any, List, Tuple, Type, TypeVar, Union, overload
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union, overload

# The node class is the base class for all nodes in the AST. It provides attributes including the line number and fields.
# Attributes are not used when walking the tree, but are useful for debugging and for code generation.
Expand All @@ -11,6 +11,14 @@ def __init__(self, *args, **kwargs): # real signature unknown
self.integrity_exceptions = []
self.read_vars = []
self.written_vars = []
self.parent: Optional[
Union[
Subroutine_Subprogram_Node,
Function_Subprogram_Node,
Main_Program_Node,
Module_Node
]
] = None
for k, v in kwargs.items():
setattr(self, k, v)

Expand Down
174 changes: 148 additions & 26 deletions dace/frontend/fortran/ast_transforms.py
mcopik marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.

from dace.frontend.fortran import ast_components, ast_internal_classes
from typing import List, Tuple, Set
from typing import Dict, List, Optional, Tuple, Set
import copy


Expand Down Expand Up @@ -310,6 +310,65 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No

return ast_internal_classes.Execution_Part_Node(execution=newbody)

class ParentScopeAssigner(NodeVisitor):
"""
For each node, it assigns its parent scope - program, subroutine, function.

If the parent node is one of the "parent" types, we assign it as the parent.
Otherwise, we look for the parent of my parent to cover nested AST nodes within
a single scope.
"""
def __init__(self):
pass

def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_internal_classes.FNode] = None):

parent_node_types = [
ast_internal_classes.Subroutine_Subprogram_Node,
ast_internal_classes.Function_Subprogram_Node,
ast_internal_classes.Main_Program_Node,
ast_internal_classes.Module_Node
]

if parent_node is not None and type(parent_node) in parent_node_types:
node.parent = parent_node
elif parent_node is not None:
node.parent = parent_node.parent

# Copied from `generic_visit` to recursively parse all leafs
for field, value in iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast_internal_classes.FNode):
self.visit(item, node)
elif isinstance(value, ast_internal_classes.FNode):
self.visit(value, node)

class ScopeVarsDeclarations(NodeVisitor):
"""
Creates a mapping (scope name, variable name) -> variable declaration.

The visitor is used to access information on variable dimension, sizes, and offsets.
"""

def __init__(self):

self.scope_vars: Dict[Tuple[str, str], ast_internal_classes.FNode] = {}

def get_var(self, scope: ast_internal_classes.FNode, variable_name: str) -> ast_internal_classes.FNode:
return self.scope_vars[(self._scope_name(scope), variable_name)]

def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node):

parent_name = self._scope_name(node.parent)
var_name = node.name
self.scope_vars[(parent_name, var_name)] = node

def _scope_name(self, scope: ast_internal_classes.FNode) -> str:
if isinstance(scope, ast_internal_classes.Main_Program_Node):
return scope.name.name.name
else:
return scope.name.name

class IndexExtractorNodeLister(NodeVisitor):
"""
Expand All @@ -336,9 +395,20 @@ class IndexExtractor(NodeTransformer):
Uses the IndexExtractorNodeLister to find all array subscript expressions
in the AST node and its children that have to be extracted into independent expressions
It then creates a new temporary variable for each of them and replaces the index expression with the variable.

Before parsing the AST, the transformation first runs:
- ParentScopeAssigner to ensure that each node knows its scope assigner.
- ScopeVarsDeclarations to aggregate all variable declarations for each function.
"""
def __init__(self, count=0):
def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = False, count=0):

self.count = count
self.normalize_offsets = normalize_offsets

if normalize_offsets:
ParentScopeAssigner().visit(ast)
self.scope_vars = ScopeVarsDeclarations()
self.scope_vars.visit(ast)

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]:
Expand Down Expand Up @@ -367,9 +437,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
lister.visit(child)
res = lister.nodes
temp = self.count


if res is not None:
for j in res:
for i in j.indices:
for idx, i in enumerate(j.indices):
if isinstance(i, ast_internal_classes.ParDecl_Node):
continue
else:
Expand All @@ -383,16 +455,34 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
line_number=child.line_number)
],
line_number=child.line_number))
newbody.append(
ast_internal_classes.BinOp_Node(
op="=",
lval=ast_internal_classes.Name_Node(name=tmp_name),
rval=ast_internal_classes.BinOp_Node(
op="-",
lval=i,
rval=ast_internal_classes.Int_Literal_Node(value="1"),
line_number=child.line_number),
line_number=child.line_number))
if self.normalize_offsets:

# Find the offset of a variable to which we are assigning
var_name = child.lval.name.name
variable = self.scope_vars.get_var(child.parent, var_name)
offset = variable.offsets[idx]

newbody.append(
ast_internal_classes.BinOp_Node(
op="=",
lval=ast_internal_classes.Name_Node(name=tmp_name),
rval=ast_internal_classes.BinOp_Node(
op="-",
lval=i,
rval=ast_internal_classes.Int_Literal_Node(value=str(offset)),
line_number=child.line_number),
line_number=child.line_number))
else:
newbody.append(
ast_internal_classes.BinOp_Node(
op="=",
lval=ast_internal_classes.Name_Node(name=tmp_name),
rval=ast_internal_classes.BinOp_Node(
op="-",
lval=i,
rval=ast_internal_classes.Int_Literal_Node(value="1"),
line_number=child.line_number),
line_number=child.line_number))
newbody.append(self.visit(child))
return ast_internal_classes.Execution_Part_Node(execution=newbody)

Expand Down Expand Up @@ -646,6 +736,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,
rangepos: list,
count: int,
newbody: list,
scope_vars: ScopeVarsDeclarations,
declaration=True,
is_sum_to_loop=False):
"""
Expand All @@ -662,16 +753,40 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,

currentindex = 0
indices = []
for i in node.indices:
offsets = scope_vars.get_var(node.parent, node.name.name).offsets

for idx, i in enumerate(node.indices):
if isinstance(i, ast_internal_classes.ParDecl_Node):

if i.type == "ALL":
ranges.append([
ast_internal_classes.Int_Literal_Node(value="1"),
ast_internal_classes.Name_Range_Node(name="f2dace_MAX",
type="INTEGER",
arrname=node.name,
pos=currentindex)
])

lower_boundary = None
if offsets[idx] != 1:
lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
else:
lower_boundary = ast_internal_classes.Int_Literal_Node(value="1")

upper_boundary = ast_internal_classes.Name_Range_Node(name="f2dace_MAX",
type="INTEGER",
arrname=node.name,
pos=currentindex)
"""
When there's an offset, we add MAX_RANGE + offset.
But since the generated loop has `<=` condition, we need to subtract 1.
"""
if offsets[idx] != 1:
upper_boundary = ast_internal_classes.BinOp_Node(
lval=upper_boundary,
op="+",
rval=ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
)
upper_boundary = ast_internal_classes.BinOp_Node(
lval=upper_boundary,
op="-",
rval=ast_internal_classes.Int_Literal_Node(value="1")
)
ranges.append([lower_boundary, upper_boundary])

else:
ranges.append([i.range[0], i.range[1]])
rangepos.append(currentindex)
Expand All @@ -693,9 +808,13 @@ class ArrayToLoop(NodeTransformer):
"""
Transforms the AST by removing array expressions and replacing them with loops
"""
def __init__(self):
def __init__(self, ast):
self.count = 0

ParentScopeAssigner().visit(ast)
self.scope_vars = ScopeVarsDeclarations()
self.scope_vars.visit(ast)

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
newbody = []
for child in node.execution:
Expand All @@ -709,15 +828,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
val = child.rval
ranges = []
rangepos = []
par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, True)
par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, self.scope_vars, True)

if res_range is not None and len(res_range) > 0:
rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)]
for i in rvals:
rangeposrval = []
rangesrval = []

par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, False)
par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False)

for i, j in zip(ranges, rangesrval):
if i != j:
Expand Down Expand Up @@ -791,8 +910,11 @@ class SumToLoop(NodeTransformer):
"""
Transforms the AST by removing array sums and replacing them with loops
"""
def __init__(self):
def __init__(self, ast):
self.count = 0
ParentScopeAssigner().visit(ast)
self.scope_vars = ScopeVarsDeclarations()
self.scope_vars.visit(ast)

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
newbody = []
Expand All @@ -811,7 +933,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
rangeposrval = []
rangesrval = []

par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, False, True)
par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False, True)

range_index = 0
body = ast_internal_classes.BinOp_Node(lval=current,
Expand Down
50 changes: 43 additions & 7 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG):
for i in node:
self.translate(i, sdfg)
else:
warnings.warn("WARNING:", node.__class__.__name__)
warnings.warn(f"WARNING: {node.__class__.__name__}")

def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG):
"""
Expand Down Expand Up @@ -1015,10 +1015,46 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG):
if node.name not in self.contexts[sdfg.name].containers:
self.contexts[sdfg.name].containers.append(node.name)

def create_ast_from_string(
source_string: str,
sdfg_name: str,
transform: bool = False,
normalize_offsets: bool = False
):
"""
Creates an AST from a Fortran file in a string
:param source_string: The fortran file as a string
:param sdfg_name: The name to be given to the resulting SDFG
:return: The resulting AST

"""
parser = pf().create(std="f2008")
reader = fsr(source_string)
ast = parser(reader)
tables = SymbolTable
own_ast = ast_components.InternalFortranAst(ast, tables)
program = own_ast.create_ast(ast)

functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines()
functions_and_subroutines_builder.visit(program)
functions_and_subroutines = functions_and_subroutines_builder.nodes

if transform:
program = ast_transforms.functionStatementEliminator(program)
program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program)
program = ast_transforms.CallExtractor().visit(program)
program = ast_transforms.SignToIf().visit(program)
program = ast_transforms.ArrayToLoop(program).visit(program)
program = ast_transforms.SumToLoop(program).visit(program)
program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)

return (program, own_ast)

def create_sdfg_from_string(
source_string: str,
sdfg_name: str,
normalize_offsets: bool = False
):
"""
Creates an SDFG from a fortran file in a string
Expand All @@ -1040,10 +1076,10 @@ def create_sdfg_from_string(
program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program)
program = ast_transforms.CallExtractor().visit(program)
program = ast_transforms.SignToIf().visit(program)
program = ast_transforms.ArrayToLoop().visit(program)
program = ast_transforms.SumToLoop().visit(program)
program = ast_transforms.ArrayToLoop(program).visit(program)
program = ast_transforms.SumToLoop(program).visit(program)
program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor().visit(program)
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
ast2sdfg = AST_translator(own_ast, __file__)
sdfg = SDFG(sdfg_name)
ast2sdfg.top_level = program
Expand Down Expand Up @@ -1082,10 +1118,10 @@ def create_sdfg_from_fortran_file(source_string: str):
program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program)
program = ast_transforms.CallExtractor().visit(program)
program = ast_transforms.SignToIf().visit(program)
program = ast_transforms.ArrayToLoop().visit(program)
program = ast_transforms.SumToLoop().visit(program)
program = ast_transforms.ArrayToLoop(program).visit(program)
program = ast_transforms.SumToLoop(program).visit(program)
program = ast_transforms.ForDeclarer().visit(program)
program = ast_transforms.IndexExtractor().visit(program)
program = ast_transforms.IndexExtractor(program).visit(program)
ast2sdfg = AST_translator(own_ast, __file__)
sdfg = SDFG(source_string)
ast2sdfg.top_level = program
Expand Down
Loading