Skip to content

Commit

Permalink
Added another test
Browse files Browse the repository at this point in the history
  • Loading branch information
LonelyCat124 committed Jan 10, 2025
1 parent a506244 commit 75b77d7
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 18 deletions.
45 changes: 41 additions & 4 deletions src/psyclone/psyir/transformations/scalarization_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ class ScalarizationTrans(LoopTrans):

@staticmethod
def _is_local_array(signature, var_accesses):
'''
:param signature: The signature to check if it is a local array symbol
or not.
:type signature: :py:class:`psyclone.core.Signature`
:param var_accesses: The VariableAccessesInfo object containing
signature.
:type var_accesses: :py:class:`psyclone.core.VariablesAccessInfo`
:returns bool: whether the symbol corresponding to signature is a
local symbol or not.
'''
if not var_accesses[signature].is_array():
return False
base_symbol = var_accesses[signature].all_accesses[0].node.symbol
Expand All @@ -59,6 +69,16 @@ def _is_local_array(signature, var_accesses):

@staticmethod
def _have_same_unmodified_index(signature, var_accesses):
'''
:param signature: The signature to check.
:type signature: :py:class:`psyclone.core.Signature`
:param var_accesses: The VariableAccessesInfo object containing
signature.
:type var_accesses: :py:class:`psyclone.core.VariablesAccessInfo`
:returns bool: whether all the array accesses to signature use the
same index, and whether the index is unmodified in
the code region.
'''
array_indices = None
scalarizable = True
for access in var_accesses[signature].all_accesses:
Expand All @@ -84,13 +104,30 @@ def _have_same_unmodified_index(signature, var_accesses):
return scalarizable

@staticmethod
def _check_first_access_is_write(sig, var_accesses):
if var_accesses[sig].is_written_first():
def _check_first_access_is_write(signature, var_accesses):
'''
:param signature: The signature to check.
:type signature: :py:class:`psyclone.core.Signature`
:param var_accesses: The VariableAccessesInfo object containing
signature.
:type var_accesses: :py:class:`psyclone.core.VariablesAccessInfo`
:returns bool: whether the first access to signature is a write.
'''
if var_accesses[signature].is_written_first():
return True
return False

@staticmethod
def _value_unused_after_loop(sig, node, var_accesses):
def _value_unused_after_loop(sig, var_accesses):
'''
:param sig: The signature to check.
:type sig: :py:class:`psyclone.core.Signature`
:param var_accesses: The VariableAccessesInfo object containing
signature.
:type var_accesses: :py:class:`psyclone.core.VariablesAccessInfo`
:returns bool: whether the value computed in the loop containing
sig is read from after the loop.
'''
# Find the last access of the signature
last_access = var_accesses[sig].all_accesses[-1].node
# Find the next accesses to this symbol
Expand Down Expand Up @@ -218,7 +255,7 @@ def apply(self, node, options=None):
# Check the values written to these arrays are not used after this loop
finalised_targets = filter(
lambda sig:
ScalarizationTrans._value_unused_after_loop(sig, node,
ScalarizationTrans._value_unused_after_loop(sig,
var_accesses),
potential_targets)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,10 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert not ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)

# Test we ignore array next_access if they're in an if statement
Expand Down Expand Up @@ -238,12 +236,10 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)
# Test we don't ignore array next_access if they're in an if statement
# that is an ancestor of the loop we're scalarizing
Expand Down Expand Up @@ -271,12 +267,10 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)

# Test we don't ignore array next_access if they have an ancestor
Expand Down Expand Up @@ -306,12 +300,10 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert not ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)

# Test being a while condition correctly counts as being used.
Expand Down Expand Up @@ -341,7 +333,6 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert not ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)

# Test being a loop start/stop/step condition correctly counts
Expand Down Expand Up @@ -374,17 +365,14 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert not ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert not ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)
# Test c
assert var_accesses[keys[3]].var_name == "c"
assert not ScalarizationTrans._value_unused_after_loop(keys[3],
node,
var_accesses)

# Test being a symbol in a Codeblock counts as used
Expand Down Expand Up @@ -414,7 +402,6 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert not ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)

# Test being in an IfBlock condition counts as used.
Expand Down Expand Up @@ -446,7 +433,6 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert not ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)


Expand Down Expand Up @@ -497,3 +483,60 @@ def test_scalarization_trans_apply(fortran_reader, fortran_writer, tmpdir):
out = fortran_writer(psyir)
assert correct in out
assert Compile(tmpdir).string_compiles(out)

# Use in if/else where the if has write only followup and
# the else has a read - shouldn't scalarise b.
code = '''subroutine test()
integer :: i
integer :: k
integer, dimension(1:100) :: arr
integer, dimension(1:100) :: b
integer, dimension(1:100) :: c
do i = 1, 100
arr(i) = i
arr(i) = exp(arr(i))
k = i
b(i) = arr(i) * 3
c(k) = i
end do
do i = 1, 100
if(c(i) > 50) then
b(i) = c(i)
else
b(i) = b(i) + c(i)
end if
end do
end subroutine
'''
strans = ScalarizationTrans()
psyir = fortran_reader.psyir_from_source(code)

loop = psyir.children[0].children[0]
strans.apply(loop)
correct = '''subroutine test()
integer :: i
integer :: k
integer, dimension(100) :: arr
integer, dimension(100) :: b
integer, dimension(100) :: c
integer :: arr_scalar
do i = 1, 100, 1
arr_scalar = i
arr_scalar = EXP(arr_scalar)
k = i
b(i) = arr_scalar * 3
c(k) = i
enddo
do i = 1, 100, 1
if (c(i) > 50) then
b(i) = c(i)
else
b(i) = b(i) + c(i)
end if
enddo'''
out = fortran_writer(psyir)
print(out)
assert correct in out
assert Compile(tmpdir).string_compiles(out)

0 comments on commit 75b77d7

Please sign in to comment.