diff --git a/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py b/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py index c18a915ba9..399e619dc0 100644 --- a/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py @@ -314,6 +314,79 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader): node, var_accesses) + # Test being a while condition correctly counts as being used. + code = '''subroutine test() + use my_mod + integer :: i + integer :: k + integer, dimension(1:100) :: arr + integer, dimension(1:100) :: b + + do i = 1, 100 + arr(i) = exp(arr(i)) + b(i) = arr(i) * 3 + end do + do i = 1, 100 + do while(b(i) < 256) + b(i) = arr(i) * arr(i) + arr(i) = arr(i) * 2 + end do + end do + end subroutine test + ''' + psyir = fortran_reader.psyir_from_source(code) + node = psyir.children[0].children[0] + var_accesses = VariablesAccessInfo(nodes=node.loop_body) + keys = list(var_accesses.keys()) + # Test b + assert var_accesses[keys[2]].var_name == "b" + assert not ScalarizationTrans._value_unused_after_loop(keys[2], + node, + var_accesses) + + # Test being a loop start/stop/step condition correctly counts + # as being used. + code = '''subroutine test() + use my_mod + integer :: i + integer :: k + integer, dimension(1:100) :: arr + integer, dimension(1:100) :: b + integer, dimension(1:100) :: c + integer, dimension(1:100, 1:100) :: d + + do i = 1, 100 + arr(i) = exp(arr(i)) + b(i) = arr(i) * 3 + c(i) = i + end do + do i = 1, 100 + do k = arr(i), b(i), c(i) + d(i,k) = i + end do + end do + end subroutine test + ''' + psyir = fortran_reader.psyir_from_source(code) + node = psyir.children[0].children[0] + var_accesses = VariablesAccessInfo(nodes=node.loop_body) + keys = list(var_accesses.keys()) + # Test arr + assert var_accesses[keys[1]].var_name == "arr" + assert not ScalarizationTrans._value_unused_after_loop(keys[1], + node, + var_accesses) + # Test b + assert var_accesses[keys[2]].var_name == "b" + assert not ScalarizationTrans._value_unused_after_loop(keys[2], + node, + var_accesses) + # Test b + assert var_accesses[keys[3]].var_name == "c" + assert not ScalarizationTrans._value_unused_after_loop(keys[3], + node, + var_accesses) + def test_scalarization_trans_apply(fortran_reader, fortran_writer, tmpdir): ''' Test the application of the scalarization transformation.'''