Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mixed str/dict inputs/outputs to tasks #345

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,18 +771,22 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
if the former convert them to a list of `dict`s with `name` as the key.

:param inout_list: The input/output list to be validated.
:param list_type: "input" or "output" to indicate what is to be validated.
:raises TypeError: If a list of mixed or wrong types is provided to the task
:param list_type: "inputs" or "outputs" to indicate what is to be validated for better error message.
:raises TypeError: If wrong types are provided to the task
:return: Processed `inputs`/`outputs` list.
"""

if all(isinstance(item, str) for item in inout_list):
return [{"name": item} for item in inout_list]
elif all(isinstance(item, dict) for item in inout_list):
return inout_list
elif not all(isinstance(item, dict) for item in inout_list):
if not all(isinstance(item, (dict, str)) for item in inout_list):
raise TypeError(
f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types."
f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`."
)
else:
raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.")

processed_inout_list = []

for item in inout_list:
if isinstance(item, str):
processed_inout_list.append({"name": item})
elif isinstance(item, dict):
processed_inout_list.append(item)

return processed_inout_list
25 changes: 24 additions & 1 deletion docs/gallery/concept/autogen/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,32 @@ def add_minus(x, y):
if "." not in output.name:
print(f" - {output.name}")

######################################################################
# For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition,
# as a shortcut, it is also possible to pass a list of strings. In that case, WorkGraph will internally convert the list
# of strings into a list of dictionaries in which case, each ``name`` key will be assigned each passed string value.
# Furthermore, also a mixed list of string and dict elements can be passed, which can be useful in cases where multiple
# outputs should be specified, but more detailed properties are only required for some of the outputs. The above also
# applies for the ``outputs`` argument of the ``@task`` decorator introduced earlier, as well as the ``inputs``, given
# that they are explicitly specified rather than derived from the signature of the ``Callable``. Finally, all lines
# below are valid specifiers for the ``outputs`` of the ``build_task`:
#

NormTask = build_task(norm, outputs=["norm"])
NormTask = build_task(norm, outputs=["norm", "norm2"])
NormTask = build_task(
norm, outputs=["norm", {"name": "norm2", "identifier": "workgraph.Any"}]
)
NormTask = build_task(
norm,
outputs=[
{"name": "norm", "identifier": "workgraph.Any"},
{"name": "norm2", "identifier": "workgraph.Any"},
],
)

######################################################################
# One can use these AiiDA component direclty in the WorkGraph. The inputs
# One can use these AiiDA component directly in the WorkGraph. The inputs
# and outputs of the task is automatically generated based on the input
# and output port of the AiiDA component. In case of ``calcfunction``, the
# default output is ``result``. If there are more than one output task,
Expand Down
1 change: 1 addition & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def add(x: data_type):
add_task.set({"x": "{{variable}}"})


@pytest.mark.skip(reason="not stable for the moment.")
@pytest.mark.parametrize(
"data_type, data",
(
Expand Down
57 changes: 20 additions & 37 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,66 +3,49 @@
from aiida_workgraph.utils import validate_task_inout


def test_validate_task_inout_empty_list():
"""Test validation with a list of strings."""
input_list = []
result = validate_task_inout(input_list, "inputs")
assert result == []


def test_validate_task_inout_str_list():
"""Test validation with a list of strings."""
input_list = ["task1", "task2"]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == [{"name": "task1"}, {"name": "task2"}]


def test_validate_task_inout_dict_list():
"""Test validation with a list of dictionaries."""
input_list = [{"name": "task1"}, {"name": "task2"}]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == input_list


@pytest.mark.parametrize(
"input_list, list_type, expected_error",
[
# Mixed types error cases
(
["task1", {"name": "task2"}],
"input",
"Provide either a list of `str` or `dict` as `input`, not mixed types.",
),
(
[{"name": "task1"}, "task2"],
"output",
"Provide either a list of `str` or `dict` as `output`, not mixed types.",
),
# Empty list cases
([], "input", None),
([], "output", None),
],
)
def test_validate_task_inout_mixed_types(input_list, list_type, expected_error):
"""Test error handling for mixed type lists."""
if expected_error:
with pytest.raises(TypeError) as excinfo:
validate_task_inout(input_list, list_type)
assert str(excinfo.value) == expected_error
else:
# For empty lists, no error should be raised
result = validate_task_inout(input_list, list_type)
assert result == []
def test_validate_task_inout_mixed_list():
"""Test validation with a list of dictionaries."""
input_list = ["task1", {"name": "task2"}]
result = validate_task_inout(input_list, "inputs")
assert result == [{"name": "task1"}, {"name": "task2"}]


@pytest.mark.parametrize(
"input_list, list_type",
[
# Invalid type cases
([1, 2, 3], "input"),
([None, None], "output"),
([True, False], "input"),
(["task", 123], "output"),
([1, 2, 3], "inputs"),
([None, None], "outputs"),
([True, False], "inputs"),
(["task", 123], "outputs"),
],
)
def test_validate_task_inout_invalid_types(input_list, list_type):
"""Test error handling for completely invalid type lists."""
with pytest.raises(TypeError) as excinfo:
validate_task_inout(input_list, list_type)
assert "Provide either a list of" in str(excinfo.value)
assert "Wrong type provided" in str(excinfo.value)


def test_validate_task_inout_dict_with_extra_keys():
Expand All @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys():
{"name": "task1", "description": "first task"},
{"name": "task2", "priority": "high"},
]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == input_list
Loading