Skip to content

Commit

Permalink
#1664: Improve tests over the sync/wait clause
Browse files Browse the repository at this point in the history
  • Loading branch information
svalat committed Mar 16, 2023
1 parent ccd0048 commit d01c00c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 18 deletions.
24 changes: 12 additions & 12 deletions src/psyclone/psyir/nodes/acc_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ def async_queue(self):
@async_queue.setter
def async_queue(self, async_queue):
'''
:param bool async_stream: wheter or not to add the 'async' close
:param bool async_queue: wheter or not to add the 'async' close
and attach to which stream.
'''
# check
if async_queue != None and not isinstance(async_queue, (bool, Signature, int)):
raise TypeError("Invalid async_stream value, expect Signature or integer or None or False")
raise TypeError("Invalid async_queue value, expect Signature or integer or None or False")

# assign
self._async_queue = async_queue
Expand Down Expand Up @@ -335,12 +335,12 @@ def async_queue(self):
@async_queue.setter
def async_queue(self, async_queue):
'''
:param bool async_stream: wheter or not to add the 'async' close
:param bool async_queue: wheter or not to add the 'async' close
and attach to which stream.
'''
# check
if async_queue != None and not isinstance(async_queue, (bool, Signature, int)):
raise TypeError("Invalid async_stream value, expect Signature or integer or None or False")
raise TypeError("Invalid async_queue value, expect Signature or integer or None or False")

# assign
self._async_queue = async_queue
Expand Down Expand Up @@ -610,10 +610,10 @@ class ACCKernelsDirective(ACCRegionDirective):
:type parent: sub-class of :py:class:`psyclone.psyir.nodes.Node`
:param bool default_present: whether or not to add the "default(present)" \
clause to the kernels directive.
:param async_stream: Make the directive asynchonous and attached to the given
:param async_queue: Make the directive asynchonous and attached to the given
steam identified by an ID or by a variable name pointing to
an integer.
:type async_stream: bool/Signature/int
:type async_queue: bool/Signature/int
:raises NotImplementedError: if default_present is False.
Expand Down Expand Up @@ -659,12 +659,12 @@ def async_queue(self):
@async_queue.setter
def async_queue(self, async_queue):
'''
:param bool async_stream: wheter or not to add the 'async' close
:param bool async_queue: wheter or not to add the 'async' close
and attach to which stream.
'''
# check
if async_queue != None and not isinstance(async_queue, (bool, Signature, int)):
raise TypeError("Invalid async_stream value, expect Signature or integer or None or False")
raise TypeError("Invalid async_queue value, expect Signature or integer or None or False")

# assign
self._async_queue = async_queue
Expand Down Expand Up @@ -867,10 +867,10 @@ class ACCUpdateDirective(ACCStandaloneDirective):
clause on the update directive (this instructs the
directive to silently ignore any variables that are not
on the device).
:param async_stream: Make the directive asynchonous and attached to the given
:param async_queue: Make the directive asynchonous and attached to the given
steam identified by an ID or by a variable name pointing to
an integer.
:type async_stream: None/str/int
:type async_queue: None/str/int
:type if_present: Optional[bool]
'''

Expand Down Expand Up @@ -985,12 +985,12 @@ def if_present(self, if_present):
@async_queue.setter
def async_queue(self, async_queue):
'''
:param bool async_stream: whether or not to add the 'async' close
:param bool async_queue: whether or not to add the 'async' close
and attach to which stream.
'''
# check
if async_queue != None and not isinstance(async_queue, (bool, Signature, int)):
raise TypeError("Invalid async_stream value, expect Signature or integer or None or False")
raise TypeError("Invalid async_queue value, expect Signature or integer or None or False")

# assign
self._async_queue = async_queue
Expand Down
66 changes: 60 additions & 6 deletions src/psyclone/tests/psyir/nodes/acc_directives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from psyclone.psyGen import PSyFactory
from psyclone.psyir.nodes import ACCRoutineDirective, \
ACCKernelsDirective, Schedule, ACCUpdateDirective, ACCLoopDirective, \
ACCWaitDirective, Routine
ACCWaitDirective, Routine, ACCParallelDirective
from psyclone.psyir.symbols import SymbolTable
from psyclone.transformations import ACCEnterDataTrans, ACCParallelTrans, \
ACCKernelsTrans
Expand All @@ -66,10 +66,8 @@ def setup():
yield
Config._instance = None


# Class ACCEnterDataDirective start


# (1/4) Method gen_code
def test_accenterdatadirective_gencode_1():
'''Test that an OpenACC Enter Data directive, when added to a schedule
Expand Down Expand Up @@ -455,8 +453,9 @@ def test_accupdatedirective_equality():

# Class ACCWaitDirective

# (1/1) Method __init__
def test_accwaitdirective_init():
'''Test init of ACCWaitDirective.'''

directive1 = ACCWaitDirective(None)
assert directive1.wait_queue == None

Expand All @@ -472,8 +471,9 @@ def test_accwaitdirective_init():
with pytest.raises(TypeError):
directive5 = ACCWaitDirective(3.5)

# (1/1) Method begin_string
def test_accwaitdirective_begin_string():
'''Test begin_string of ACCWaitDirective.'''

directive1 = ACCWaitDirective(None)
assert directive1.begin_string() == "acc wait"

Expand All @@ -486,11 +486,65 @@ def test_accwaitdirective_begin_string():
directive4 = ACCWaitDirective(Signature("variable_name"))
assert directive4.begin_string() == "acc wait (variable_name)"

# (1/1) Method gencode
def test_accwaitdirective_gencode():
'''Test gen code of ACCWaitDirective'''

_, info = parse(os.path.join(BASE_PATH, "1_single_invoke.f90"))
psy = PSyFactory(distributed_memory=False).create(info)
routines = psy.container.walk(Routine)
routines[0].children.append(ACCWaitDirective(1))
code = str(psy.gen)
assert '$acc wait (1)' in code

def test_accwaitdirective_eq():
'''Test the __eq__ implementation of ACCWaitDirective.'''

# build some
directive1 = ACCWaitDirective(1)
directive2 = ACCWaitDirective(1)
directive3 = ACCWaitDirective(Signature('stream1'))

# check equality
assert directive1 == directive2
assert not (directive1 == directive3)

# async keyword on all classes

@pytest.mark.parametrize("directive_type", [ACCKernelsDirective, ACCParallelDirective, ACCUpdateDirective])
def test_directives_async_queue(directive_type):
'''Validate the various usage of async_queue parameter'''

# args
args = []
if directive_type == ACCUpdateDirective:
args = [[Signature('x')], 'host']

# set value at init
directive = directive_type(*args, async_queue=1)
assert directive.async_queue == 1
assert 'async(1)' in directive.begin_string()

# change value to true
directive.async_queue = True
assert directive.async_queue == True
assert 'async()' in directive.begin_string()

# change value to False
directive.async_queue = False
assert directive.async_queue == False
assert not 'async()' in directive.begin_string()

# change value to None
directive.async_queue = None
assert directive.async_queue == None
assert not 'async()' in directive.begin_string()

# change value afterward
directive.async_queue = Signature("stream")
assert directive.async_queue == Signature("stream")
assert 'async(stream)' in directive.begin_string()

# put wrong type
with pytest.raises(TypeError) as error:
directive.async_queue = 3.5
assert "Invalid async_queue" in str(error)

0 comments on commit d01c00c

Please sign in to comment.