Skip to content

Commit

Permalink
Updated tests for coverage. Now need to do some more complex function…
Browse files Browse the repository at this point in the history
…ality tests
  • Loading branch information
LonelyCat124 committed Jan 10, 2025
1 parent 2988358 commit dc69373
Showing 1 changed file with 73 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,79 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
node,
var_accesses)

# Test being a while condition correctly counts as being used.
code = '''subroutine test()
use my_mod
integer :: i
integer :: k
integer, dimension(1:100) :: arr
integer, dimension(1:100) :: b
do i = 1, 100
arr(i) = exp(arr(i))
b(i) = arr(i) * 3
end do
do i = 1, 100
do while(b(i) < 256)
b(i) = arr(i) * arr(i)
arr(i) = arr(i) * 2
end do
end do
end subroutine test
'''
psyir = fortran_reader.psyir_from_source(code)
node = psyir.children[0].children[0]
var_accesses = VariablesAccessInfo(nodes=node.loop_body)
keys = list(var_accesses.keys())
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert not ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)

# Test being a loop start/stop/step condition correctly counts
# as being used.
code = '''subroutine test()
use my_mod
integer :: i
integer :: k
integer, dimension(1:100) :: arr
integer, dimension(1:100) :: b
integer, dimension(1:100) :: c
integer, dimension(1:100, 1:100) :: d
do i = 1, 100
arr(i) = exp(arr(i))
b(i) = arr(i) * 3
c(i) = i
end do
do i = 1, 100
do k = arr(i), b(i), c(i)
d(i,k) = i
end do
end do
end subroutine test
'''
psyir = fortran_reader.psyir_from_source(code)
node = psyir.children[0].children[0]
var_accesses = VariablesAccessInfo(nodes=node.loop_body)
keys = list(var_accesses.keys())
# Test arr
assert var_accesses[keys[1]].var_name == "arr"
assert not ScalarizationTrans._value_unused_after_loop(keys[1],
node,
var_accesses)
# Test b
assert var_accesses[keys[2]].var_name == "b"
assert not ScalarizationTrans._value_unused_after_loop(keys[2],
node,
var_accesses)
# Test b
assert var_accesses[keys[3]].var_name == "c"
assert not ScalarizationTrans._value_unused_after_loop(keys[3],
node,
var_accesses)


def test_scalarization_trans_apply(fortran_reader, fortran_writer, tmpdir):
''' Test the application of the scalarization transformation.'''
Expand Down

0 comments on commit dc69373

Please sign in to comment.