Skip to content

Commit

Permalink
Changes to use filter to make the code easier to understand
Browse files Browse the repository at this point in the history
  • Loading branch information
LonelyCat124 committed May 20, 2024
1 parent fece2bd commit ed75417
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 243 deletions.
201 changes: 104 additions & 97 deletions src/psyclone/psyir/transformations/scalarization_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,102 +47,90 @@

class ScalarizationTrans(LoopTrans):

def _find_potential_scalarizable_array_symbols(self, node, var_accesses):

potential_arrays = []
signatures = var_accesses.all_signatures
for signature in signatures:
# Skip over non-arrays
if not var_accesses[signature].is_array():
continue
# Skip over non-local symbols
base_symbol = var_accesses[signature].all_accesses[0].node.symbol
if not base_symbol.is_automatic:
continue
array_indices = None
scalarizable = True
for access in var_accesses[signature].all_accesses:
if array_indices is None:
array_indices = access.component_indices
# For some reason using == on the component_lists doesn't work
elif array_indices[:] != access.component_indices[:]:
scalarizable = False
break
@staticmethod
def _is_local_array(signature, var_accesses):
if not var_accesses[signature].is_array():
return False
base_symbol = var_accesses[signature].all_accesses[0].node.symbol
if not base_symbol.is_automatic:
return False

return True

@staticmethod
def _have_same_unmodified_index(signature, var_accesses):
array_indices = None
scalarizable = True
for access in var_accesses[signature].all_accesses:
if array_indices is None:
array_indices = access.component_indices
# For some reason using == on the component_lists doesn't work
elif array_indices[:] != access.component_indices[:]:
scalarizable = False
break
# For each index, we need to check they're not written to in
# the loop.
flattened_indices = list(itertools.chain.from_iterable(
array_indices))
for index in flattened_indices:
sig, _ = index.get_signature_and_indices()
if var_accesses[sig].is_written():
scalarizable = False
break
if scalarizable:
potential_arrays.append(signature)

return potential_arrays

def _check_first_access_is_write(self, node, var_accesses, potentials):
potential_arrays = []

for signature in potentials:
if var_accesses[signature].is_written_first():
potential_arrays.append(signature)

return potential_arrays

def _check_valid_following_access(self, node, var_accesses, potentials):
potential_arrays = []

for signature in potentials:
# Find the last access of each signature
last_access = var_accesses[signature].all_accesses[-1].node
# Find the next access to this symbol
next_access = last_access.next_access()
# If we don't use this again then its valid
if next_access is None:
potential_arrays.append(signature)
continue
# If we do and the next_access has an ancestor IfBlock
# that isn't an ancestor of the loop then its not valid since
# we aren't tracking down what the condition-dependent next
# use really is.
if_ancestor = next_access.ancestor(IfBlock)

# If abs_position of if_ancestor is > node.abs_position
# its not an ancestor of us.
if (if_ancestor is not None and
if_ancestor.abs_position > node.abs_position):
# Not a valid next_access pattern.
continue

# If next access is the LHS of an assignment, we need to
# check that it doesn't also appear on the RHS. If so its
# not a valid access
# I'm not sure this code is reachable
# if (isinstance(next_access.parent, Assignment) and
# next_access.parent.lhs is next_access and
# (next_access.next_access() is not None and
# next_access.next_access().ancestor(Assignment) is
# next_access.parent)):
# continue

# If next access is the RHS of an assignment then we need to
# skip it
ancestor_assign = next_access.ancestor(Assignment)
if (ancestor_assign is not None and
ancestor_assign.lhs is not next_access):
continue

# If it has an ancestor that is a CodeBlock or Call or Kern
# then we can't guarantee anything, so we remove it.
if (next_access.ancestor((CodeBlock, Call, Kern))
is not None):
continue

potential_arrays.append(signature)

return potential_arrays
# Index may not be a Reference, so we need to loop over the
# References
for ref in index.walk(Reference):
sig, _ = ref.get_signature_and_indices()
if var_accesses[sig].is_written():
scalarizable = False
break

return scalarizable

@staticmethod
def _check_first_access_is_write(sig, var_accesses):
if var_accesses[sig].is_written_first():
return True
return False

@staticmethod
def _value_unused_after_loop(sig, node, var_accesses):
# Find the last access of the signature
last_access = var_accesses[sig].all_accesses[-1].node
# Find the next access to this symbol
next_access = last_access.next_access()
# If we don't use this again then this can be scalarized
if next_access is None:
return True

# If the next_access has an ancestor IfBlock and
# that isn't an ancestor of the loop then its not valid since
# we aren't tracking down what the condition-dependent next
# use really is.
if_ancestor = next_access.ancestor(IfBlock)
# If abs_position of if_ancestor is > node.abs_position
# its not an ancestor of us.
# Handles:
# if (some_condition) then
# x = next_access[i] + 1
if (if_ancestor is not None and
if_ancestor.abs_position > node.abs_position):
# Not a valid next_access pattern.
return False

# If next access is the RHS of an assignment then we need to
# skip it
# Handles:
# a = next_access[i] + 1
ancestor_assign = next_access.ancestor(Assignment)
if (ancestor_assign is not None and
ancestor_assign.lhs is not next_access):
return False

# If it has an ancestor that is a CodeBlock or Call or Kern
# then we can't guarantee anything, so we remove it.
# Handles: call my_func(next_access)
if (next_access.ancestor((CodeBlock, Call, Kern))
is not None):
return False

return True

def apply(self, node, options=None):
'''Apply the scalarization transformation to a loop.
Expand Down Expand Up @@ -182,17 +170,36 @@ def apply(self, node, options=None):

# Find all the ararys that are only accessed by a single index, and
# that index is only read inside the loop.
potential_targets = self._find_potential_scalarizable_array_symbols(
node, var_accesses)
potential_targets = filter(
lambda sig:
ScalarizationTrans._is_local_array(sig, var_accesses),
var_accesses)
potential_targets = filter(
lambda sig:
ScalarizationTrans._have_same_unmodified_index(sig,
var_accesses),
potential_targets)
# potential_targets = self._find_potential_scalarizable_array_symbols(
# node, var_accesses)

# Now we need to check the first access is a write and remove those
# that aren't.
potential_targets = self._check_first_access_is_write(
node, var_accesses, potential_targets)
potential_targets = filter(
lambda sig:
ScalarizationTrans._check_first_access_is_write(sig,
var_accesses),
potential_targets)
# potential_targets = self._check_first_access_is_write(
# node, var_accesses, potential_targets)

# Check the values written to these arrays are not used after this loop
finalised_targets = self._check_valid_following_access(
node, var_accesses, potential_targets)
finalised_targets = filter(
lambda sig:
ScalarizationTrans._value_unused_after_loop(sig, node,
var_accesses),
potential_targets)
# finalised_targets = self._check_valid_following_access(
# node, var_accesses, potential_targets)

routine_table = node.ancestor(Routine).symbol_table
# For each finalised target we can replace them with a scalarized
Expand Down
Loading

0 comments on commit ed75417

Please sign in to comment.