diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 682b2caa..17ca9823 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -771,18 +771,22 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di if the former convert them to a list of `dict`s with `name` as the key. :param inout_list: The input/output list to be validated. - :param list_type: "input" or "output" to indicate what is to be validated. - :raises TypeError: If a list of mixed or wrong types is provided to the task + :param list_type: "inputs" or "outputs" to indicate what is to be validated for better error message. + :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ - if all(isinstance(item, str) for item in inout_list): - return [{"name": item} for item in inout_list] - elif all(isinstance(item, dict) for item in inout_list): - return inout_list - elif not all(isinstance(item, dict) for item in inout_list): + if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( - f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types." + f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`." ) - else: - raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.") + + processed_inout_list = [] + + for item in inout_list: + if isinstance(item, str): + processed_inout_list.append({"name": item}) + elif isinstance(item, dict): + processed_inout_list.append(item) + + return processed_inout_list diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py index 0890fd4f..2f35df60 100644 --- a/docs/gallery/concept/autogen/task.py +++ b/docs/gallery/concept/autogen/task.py @@ -125,9 +125,32 @@ def add_minus(x, y): if "." not in output.name: print(f" - {output.name}") +###################################################################### +# For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition, +# as a shortcut, it is also possible to pass a list of strings. In that case, WorkGraph will internally convert the list +# of strings into a list of dictionaries in which case, each ``name`` key will be assigned each passed string value. +# Furthermore, also a mixed list of string and dict elements can be passed, which can be useful in cases where multiple +# outputs should be specified, but more detailed properties are only required for some of the outputs. The above also +# applies for the ``outputs`` argument of the ``@task`` decorator introduced earlier, as well as the ``inputs``, given +# that they are explicitly specified rather than derived from the signature of the ``Callable``. Finally, all lines +# below are valid specifiers for the ``outputs`` of the ``build_task`: +# + +NormTask = build_task(norm, outputs=["norm"]) +NormTask = build_task(norm, outputs=["norm", "norm2"]) +NormTask = build_task( + norm, outputs=["norm", {"name": "norm2", "identifier": "workgraph.Any"}] +) +NormTask = build_task( + norm, + outputs=[ + {"name": "norm", "identifier": "workgraph.Any"}, + {"name": "norm2", "identifier": "workgraph.Any"}, + ], +) ###################################################################### -# One can use these AiiDA component direclty in the WorkGraph. The inputs +# One can use these AiiDA component directly in the WorkGraph. The inputs # and outputs of the task is automatically generated based on the input # and output port of the AiiDA component. In case of ``calcfunction``, the # default output is ``result``. If there are more than one output task, diff --git a/tests/test_socket.py b/tests/test_socket.py index 6e050158..7b359b39 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -57,6 +57,7 @@ def add(x: data_type): add_task.set({"x": "{{variable}}"}) +@pytest.mark.skip(reason="not stable for the moment.") @pytest.mark.parametrize( "data_type, data", ( diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e42db34..8ecaaab6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,66 +3,49 @@ from aiida_workgraph.utils import validate_task_inout +def test_validate_task_inout_empty_list(): + """Test validation with a list of strings.""" + input_list = [] + result = validate_task_inout(input_list, "inputs") + assert result == [] + + def test_validate_task_inout_str_list(): """Test validation with a list of strings.""" input_list = ["task1", "task2"] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == [{"name": "task1"}, {"name": "task2"}] def test_validate_task_inout_dict_list(): """Test validation with a list of dictionaries.""" input_list = [{"name": "task1"}, {"name": "task2"}] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list -@pytest.mark.parametrize( - "input_list, list_type, expected_error", - [ - # Mixed types error cases - ( - ["task1", {"name": "task2"}], - "input", - "Provide either a list of `str` or `dict` as `input`, not mixed types.", - ), - ( - [{"name": "task1"}, "task2"], - "output", - "Provide either a list of `str` or `dict` as `output`, not mixed types.", - ), - # Empty list cases - ([], "input", None), - ([], "output", None), - ], -) -def test_validate_task_inout_mixed_types(input_list, list_type, expected_error): - """Test error handling for mixed type lists.""" - if expected_error: - with pytest.raises(TypeError) as excinfo: - validate_task_inout(input_list, list_type) - assert str(excinfo.value) == expected_error - else: - # For empty lists, no error should be raised - result = validate_task_inout(input_list, list_type) - assert result == [] +def test_validate_task_inout_mixed_list(): + """Test validation with a list of dictionaries.""" + input_list = ["task1", {"name": "task2"}] + result = validate_task_inout(input_list, "inputs") + assert result == [{"name": "task1"}, {"name": "task2"}] @pytest.mark.parametrize( "input_list, list_type", [ # Invalid type cases - ([1, 2, 3], "input"), - ([None, None], "output"), - ([True, False], "input"), - (["task", 123], "output"), + ([1, 2, 3], "inputs"), + ([None, None], "outputs"), + ([True, False], "inputs"), + (["task", 123], "outputs"), ], ) def test_validate_task_inout_invalid_types(input_list, list_type): """Test error handling for completely invalid type lists.""" with pytest.raises(TypeError) as excinfo: validate_task_inout(input_list, list_type) - assert "Provide either a list of" in str(excinfo.value) + assert "Wrong type provided" in str(excinfo.value) def test_validate_task_inout_dict_with_extra_keys(): @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys(): {"name": "task1", "description": "first task"}, {"name": "task2", "priority": "high"}, ] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list