From 1d43042ddb8dfceab690e0d839ebfe23c471829a Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Tue, 19 Nov 2024 09:51:59 +0100 Subject: [PATCH 1/7] Allow mixed `str`/`dict` inputs/outputs to tasks --- aiida_workgraph/utils/__init__.py | 24 +++++++------ tests/test_utils.py | 57 +++++++++++-------------------- 2 files changed, 34 insertions(+), 47 deletions(-) diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 682b2caa..ef4cc052 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -771,18 +771,22 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di if the former convert them to a list of `dict`s with `name` as the key. :param inout_list: The input/output list to be validated. - :param list_type: "input" or "output" to indicate what is to be validated. - :raises TypeError: If a list of mixed or wrong types is provided to the task + :param list_type: "input" or "output" to indicate what is to be validated for better error message. + :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ - if all(isinstance(item, str) for item in inout_list): - return [{"name": item} for item in inout_list] - elif all(isinstance(item, dict) for item in inout_list): - return inout_list - elif not all(isinstance(item, dict) for item in inout_list): + if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( - f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types." + f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`." ) - else: - raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.") + + processed_inout_list = [] + + for item in inout_list: + if isinstance(item, str): + processed_inout_list.append({"name": item}) + elif isinstance(item, dict): + processed_inout_list.append(item) + + return processed_inout_list diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e42db34..8ecaaab6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,66 +3,49 @@ from aiida_workgraph.utils import validate_task_inout +def test_validate_task_inout_empty_list(): + """Test validation with a list of strings.""" + input_list = [] + result = validate_task_inout(input_list, "inputs") + assert result == [] + + def test_validate_task_inout_str_list(): """Test validation with a list of strings.""" input_list = ["task1", "task2"] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == [{"name": "task1"}, {"name": "task2"}] def test_validate_task_inout_dict_list(): """Test validation with a list of dictionaries.""" input_list = [{"name": "task1"}, {"name": "task2"}] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list -@pytest.mark.parametrize( - "input_list, list_type, expected_error", - [ - # Mixed types error cases - ( - ["task1", {"name": "task2"}], - "input", - "Provide either a list of `str` or `dict` as `input`, not mixed types.", - ), - ( - [{"name": "task1"}, "task2"], - "output", - "Provide either a list of `str` or `dict` as `output`, not mixed types.", - ), - # Empty list cases - ([], "input", None), - ([], "output", None), - ], -) -def test_validate_task_inout_mixed_types(input_list, list_type, expected_error): - """Test error handling for mixed type lists.""" - if expected_error: - with pytest.raises(TypeError) as excinfo: - validate_task_inout(input_list, list_type) - assert str(excinfo.value) == expected_error - else: - # For empty lists, no error should be raised - result = validate_task_inout(input_list, list_type) - assert result == [] +def test_validate_task_inout_mixed_list(): + """Test validation with a list of dictionaries.""" + input_list = ["task1", {"name": "task2"}] + result = validate_task_inout(input_list, "inputs") + assert result == [{"name": "task1"}, {"name": "task2"}] @pytest.mark.parametrize( "input_list, list_type", [ # Invalid type cases - ([1, 2, 3], "input"), - ([None, None], "output"), - ([True, False], "input"), - (["task", 123], "output"), + ([1, 2, 3], "inputs"), + ([None, None], "outputs"), + ([True, False], "inputs"), + (["task", 123], "outputs"), ], ) def test_validate_task_inout_invalid_types(input_list, list_type): """Test error handling for completely invalid type lists.""" with pytest.raises(TypeError) as excinfo: validate_task_inout(input_list, list_type) - assert "Provide either a list of" in str(excinfo.value) + assert "Wrong type provided" in str(excinfo.value) def test_validate_task_inout_dict_with_extra_keys(): @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys(): {"name": "task1", "description": "first task"}, {"name": "task2", "priority": "high"}, ] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list From fdc63dd8a873bb29c3aee472bd0ffcc58539e944 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Tue, 19 Nov 2024 18:46:57 +0100 Subject: [PATCH 2/7] Apply suggestions from code review Co-authored-by: Xing Wang --- aiida_workgraph/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index ef4cc052..17ca9823 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -771,7 +771,7 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di if the former convert them to a list of `dict`s with `name` as the key. :param inout_list: The input/output list to be validated. - :param list_type: "input" or "output" to indicate what is to be validated for better error message. + :param list_type: "inputs" or "outputs" to indicate what is to be validated for better error message. :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ From 40ec66c7a27afeef8c2ae1dc6bfc4fd888728cc6 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 21 Nov 2024 08:17:37 +0100 Subject: [PATCH 3/7] Add explanation and examples for mixed types --- docs/gallery/concept/autogen/task.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py index 0890fd4f..2f35df60 100644 --- a/docs/gallery/concept/autogen/task.py +++ b/docs/gallery/concept/autogen/task.py @@ -125,9 +125,32 @@ def add_minus(x, y): if "." not in output.name: print(f" - {output.name}") +###################################################################### +# For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition, +# as a shortcut, it is also possible to pass a list of strings. In that case, WorkGraph will internally convert the list +# of strings into a list of dictionaries in which case, each ``name`` key will be assigned each passed string value. +# Furthermore, also a mixed list of string and dict elements can be passed, which can be useful in cases where multiple +# outputs should be specified, but more detailed properties are only required for some of the outputs. The above also +# applies for the ``outputs`` argument of the ``@task`` decorator introduced earlier, as well as the ``inputs``, given +# that they are explicitly specified rather than derived from the signature of the ``Callable``. Finally, all lines +# below are valid specifiers for the ``outputs`` of the ``build_task`: +# + +NormTask = build_task(norm, outputs=["norm"]) +NormTask = build_task(norm, outputs=["norm", "norm2"]) +NormTask = build_task( + norm, outputs=["norm", {"name": "norm2", "identifier": "workgraph.Any"}] +) +NormTask = build_task( + norm, + outputs=[ + {"name": "norm", "identifier": "workgraph.Any"}, + {"name": "norm2", "identifier": "workgraph.Any"}, + ], +) ###################################################################### -# One can use these AiiDA component direclty in the WorkGraph. The inputs +# One can use these AiiDA component directly in the WorkGraph. The inputs # and outputs of the task is automatically generated based on the input # and output port of the AiiDA component. In case of ``calcfunction``, the # default output is ``result``. If there are more than one output task, From a85d11059ae3f2021f37e297727c0c4eb4336bba Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Tue, 26 Nov 2024 10:31:25 +0100 Subject: [PATCH 4/7] Skip `test_socket_validate` test for now. --- tests/test_socket.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_socket.py b/tests/test_socket.py index 6e050158..7b359b39 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -57,6 +57,7 @@ def add(x: data_type): add_task.set({"x": "{{variable}}"}) +@pytest.mark.skip(reason="not stable for the moment.") @pytest.mark.parametrize( "data_type, data", ( From 9a337aaf30ee91f902c41c6fff6881a10cf9c744 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Mon, 2 Dec 2024 11:02:00 +0100 Subject: [PATCH 5/7] update docs --- docs/gallery/concept/autogen/task.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py index 2f35df60..35eec338 100644 --- a/docs/gallery/concept/autogen/task.py +++ b/docs/gallery/concept/autogen/task.py @@ -54,15 +54,27 @@ def multiply(x, y): ###################################################################### # If you want to change the name of the output ports, or if there are more # than one output. You can define the outputs explicitly. For example: -# ``{"name": "sum", "identifier": "workgraph.Any"}``, where the ``identifier`` -# indicates the data type. The data type tell the code how to display the -# port in the GUI, validate the data, and serialize data into database. We -# use ``workgraph.Any`` for any data type. For the moment, the data validation is + + +# define the outputs explicitly +@task(outputs=["sum", "diff"]) +def add_minus(x, y): + return {"sum": x + y, "difference": x - y} + + +print("Inputs:", add_minus.task().inputs.keys()) +print("Outputs:", add_minus.task().outputs.keys()) + +###################################################################### +# One can also add an ``identifier`` to indicates the data type. The data +# type tell the code how to display the port in the GUI, validate the data, +# and serialize data into database. +# We use ``workgraph.Any`` for any data type. For the moment, the data validation is # experimentally supported, and the GUI display is not implemented. Thus, # I suggest you to always ``workgraph.Any`` for the port. # -# define add calcfunction task +# define the outputs with identifier @task( outputs=[ {"name": "sum", "identifier": "workgraph.Any"}, @@ -73,10 +85,6 @@ def add_minus(x, y): return {"sum": x + y, "difference": x - y} -print("Inputs:", add_minus.task().inputs.keys()) -print("Outputs:", add_minus.task().outputs.keys()) - - ###################################################################### # Then, one can use the task inside the WorkGraph: # From fdc99723b33ef2e3e215f5328db71a530b659479 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Mon, 2 Dec 2024 11:12:39 +0100 Subject: [PATCH 6/7] do not skip test_socket_validate --- tests/test_socket.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_socket.py b/tests/test_socket.py index 6fc8fd50..fea72d4d 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -57,7 +57,6 @@ def add(x: data_type): add_task.set({"x": "{{variable}}"}) -@pytest.mark.skip(reason="not stable for the moment.") @pytest.mark.parametrize( "data_type, data", ( From 6531bec51ed205077021deb2c7bd7d9c326b575a Mon Sep 17 00:00:00 2001 From: superstar54 Date: Mon, 2 Dec 2024 11:30:31 +0100 Subject: [PATCH 7/7] validate task inputs and outputs when creating task --- aiida_workgraph/decorator.py | 10 ++++++++-- tests/test_decorator.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index 791f7d8c..c5551b48 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -557,8 +557,8 @@ def decorator_task( identifier: Optional[str] = None, task_type: str = "Normal", properties: Optional[List[Tuple[str, str]]] = None, - inputs: Optional[List[Tuple[str, str]]] = None, - outputs: Optional[List[Tuple[str, str]]] = None, + inputs: Optional[List[str | dict]] = None, + outputs: Optional[List[str | dict]] = None, error_handlers: Optional[List[Dict[str, Any]]] = None, catalog: str = "Others", ) -> Callable: @@ -574,6 +574,12 @@ def decorator_task( outputs (list): task outputs """ + if inputs: + inputs = validate_task_inout(inputs, "inputs") + + if outputs: + outputs = validate_task_inout(outputs, "outputs") + def decorator(func): nonlocal identifier, task_type diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 2d96c708..ff2237e6 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -3,6 +3,18 @@ from typing import Callable +def test_custom_outputs(): + """Test custom outputs.""" + + @task(outputs=["sum", {"name": "product", "identifier": "workgraph.any"}]) + def add_multiply(x, y): + return {"sum": x + y, "product": x * y} + + n = add_multiply.task() + assert "sum" in n.outputs.keys() + assert "product" in n.outputs.keys() + + @pytest.fixture(params=["decorator_factory", "decorator"]) def task_calcfunction(request): if request.param == "decorator_factory":