Skip to content

Commit

Permalink
Closes Bears-R-Us#3771: register_commands.py to handle generic scalar…
Browse files Browse the repository at this point in the history
… type (Bears-R-Us#3772)

Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
  • Loading branch information
ajpotts and ajpotts committed Sep 17, 2024
1 parent e00f10b commit 8d994e6
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 42 deletions.
10 changes: 10 additions & 0 deletions registration-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
"bool",
"bigint"
]
},
"scalar": {
"dtype": [
"int",
"uint",
"uint(8)",
"real",
"bool",
"bigint"
]
}
}
}
10 changes: 10 additions & 0 deletions src/registry/Commands.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ param regConfig = """
"bool",
"bigint"
]
},
"scalar": {
"dtype": [
"int",
"uint",
"uint(8)",
"real",
"bool",
"bigint"
]
}
}
}
Expand Down
90 changes: 48 additions & 42 deletions src/registry/register_commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import chapel
import sys
import json
import itertools
import json
import sys

import chapel

DEFAULT_MODS = ["MsgProcessing", "GenSymIO"]

Expand Down Expand Up @@ -210,6 +211,7 @@ def info_tuple(formal):
gen_formals.append(formal_info)
else:
con_formals.append(formal_info)

return con_formals, gen_formals


Expand All @@ -220,9 +222,7 @@ def clean_stamp_name(name):
return name.translate(str.maketrans("[](),=", "______"))


def stamp_generic_command(
generic_proc_name, prefix, module_name, formals, line_num, is_user_proc
):
def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc):
"""
Create code to stamp out and register a generic command using a generic
procedure, and a set values for its generic formals.
Expand Down Expand Up @@ -295,9 +295,7 @@ def parse_param_class_value(value):
if isinstance(value, list):
for v in value:
if not isinstance(v, (int, float, str)):
raise ValueError(
f"Invalid parameter value type ({type(v)}) in list '{value}'"
)
raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'")
return value
elif isinstance(value, int):
return [
Expand All @@ -313,9 +311,7 @@ def parse_param_class_value(value):
if isinstance(vals, list):
return vals
else:
raise ValueError(
f"Could not create a list of parameter values from '{value}'"
)
raise ValueError(f"Could not create a list of parameter values from '{value}'")
else:
raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'")

Expand Down Expand Up @@ -353,9 +349,7 @@ def generic_permutations(config, gen_formals):
+ "please check the 'parameter_classes' field in the configuration file"
)

to_permute[formal_name] = parse_param_class_value(
config["parameter_classes"][pclass][pname]
)
to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname])

return permutations(to_permute)

Expand Down Expand Up @@ -446,6 +440,28 @@ def unpack_scalar_arg(arg_name, arg_type):
return f"\tvar {arg_name} = {ARGS_FORMAL_NAME}['{arg_name}'].toScalar({arg_type});"


def unpack_scalar_arg_with_generic(arg_name, array_count):
"""
Generate the code to unpack a scalar argument
'scalar_count' is used to generate unique names when
a procedure has multiple array-symbol formals
Example:
```
var x = msgArgs['x'].toScalar(scalar_dtype_0);
```
Returns the chapel code, and the specifications for the
'dtype' and type-constructor arguments
"""
dtype_arg_name = "scalar_dtype_" + str(array_count)
return (
unpack_scalar_arg(arg_name, dtype_arg_name),
[(dtype_arg_name, "type", None, None)],
)


def unpack_tuple_arg(arg_name, tuple_size, scalar_type):
"""
Generate the code to unpack a tuple argument
Expand Down Expand Up @@ -492,8 +508,7 @@ def gen_signature(user_proc_name, generic_args=None):
if generic_args:
name = "ark_reg_" + user_proc_name + "_generic"
arg_strings = [
f"{kind} {name}: {ft}" if ft else f"{kind} {name}"
for name, kind, ft, _ in generic_args
f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args
]
proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}"
else:
Expand All @@ -511,11 +526,13 @@ def gen_arg_unpacking(formals):
unpack_lines = []
generic_args = []
array_arg_counter = 0
scalar_arg_counter = 0

array_domain_queries = {}
array_dtype_queries = {}

for fname, fintent, ftype, finfo in formals:

if ftype in chapel_scalar_types:
unpack_lines.append(unpack_scalar_arg(fname, ftype))
elif ftype == "<array>":
Expand Down Expand Up @@ -556,12 +573,14 @@ def gen_arg_unpacking(formals):
unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype))
else:
if ftype in array_dtype_queries.keys():
unpack_lines.append(
unpack_scalar_arg(fname, array_dtype_queries[ftype])
)

unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype]))
else:
# TODO: fully handle generic user-defined types
unpack_lines.append(unpack_user_symbol(fname, ftype))
code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter)
unpack_lines.append(code)
generic_args += scalar_args
scalar_arg_counter += 1

return ("\n".join(unpack_lines), generic_args)

Expand Down Expand Up @@ -652,14 +671,10 @@ def gen_command_proc(name, return_type, formals, mod_name):
arg_unpack, command_formals = gen_arg_unpacking(formals)
is_generic_command = len(command_formals) > 0
signature, cmd_name = gen_signature(name, command_formals)
fn_call, result_name = gen_user_function_call(
name, [f[0] for f in formals], mod_name, return_type
)
fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type)

# get the names of the array-elt-type queries in the formals
array_etype_queries = [
f[3][1] for f in formals if (f[2] == "<array>" and f[3] is not None)
]
array_etype_queries = [f[3][1] for f in formals if (f[2] == "<array>" and f[3] is not None)]

# assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference
# or if it is a `SymEntry` type-constructor call
Expand All @@ -678,30 +693,22 @@ def gen_command_proc(name, return_type, formals, mod_name):
)
)
returns_array = (
return_type
and isinstance(return_type, chapel.BracketLoop)
and return_type.is_maybe_array_type()
return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type()
)

if returns_array:
symbol_creation, result_name = gen_symbol_creation(
ARRAY_ENTRY_CLASS_NAME, result_name
)
symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name)
else:
symbol_creation = ""

response = gen_response(result_name, returns_symbol or returns_array)

command_proc = "\n".join(
[signature, arg_unpack, fn_call, symbol_creation, response, "}"]
)
command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"])

return (command_proc, cmd_name, is_generic_command, command_formals)


def stamp_out_command(
config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc
):
def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc):
"""
Yield instantiations of a generic command with using the
values from the configuration file
Expand All @@ -723,9 +730,7 @@ def stamp_out_command(
formal_perms = generic_permutations(config, formals)

for fp in formal_perms:
stamp = stamp_generic_command(
name, cmd_prefix, mod_name, fp, line_num, is_user_proc
)
stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc)
yield stamp


Expand Down Expand Up @@ -782,6 +787,7 @@ def register_commands(config, source_files):
(cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals) = gen_command_proc(
name, fn.return_type(), con_formals, mod_name
)

file_stamps.append(cmd_proc)
count += 1

Expand Down

0 comments on commit 8d994e6

Please sign in to comment.