diff --git a/gokart/__init__.py b/gokart/__init__.py index 25e54f41..26d9555f 100644 --- a/gokart/__init__.py +++ b/gokart/__init__.py @@ -1,7 +1,7 @@ from gokart.build import WorkerSchedulerFactory, build # noqa:F401 from gokart.info import make_tree_info, tree_info # noqa:F401 from gokart.pandas_type_config import PandasTypeConfig # noqa:F401 -from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401 +from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, SerializableParameter, TaskInstanceParameter # noqa:F401 from gokart.run import run # noqa:F401 from gokart.task import TaskOnKart # noqa:F401 from gokart.testing import test_run # noqa:F401 diff --git a/gokart/parameter.py b/gokart/parameter.py index de8c9556..02a91048 100644 --- a/gokart/parameter.py +++ b/gokart/parameter.py @@ -1,6 +1,7 @@ import bz2 import json from logging import getLogger +from typing import Generic, Protocol, TypeVar import luigi from luigi import task_register @@ -87,3 +88,34 @@ def __init__(self, *args, **kwargs): def _parser_kwargs(self, *args, **kwargs): # type: ignore return luigi.Parameter._parser_kwargs(*args, *kwargs) + + +T = TypeVar('T') + + +class Serializable(Protocol): + def gokart_serialize(self) -> str: + """Implement this method to serialize the object as an parameter + You can omit some fields from results of serialization if you want to ignore changes of them + """ + ... + + @classmethod + def gokart_deserialize(cls: type[T], s: str) -> T: + """Implement this method to deserialize the object from a string""" + ... + + +S = TypeVar('S', bound=Serializable) + + +class SerializableParameter(luigi.Parameter, Generic[S]): + def __init__(self, object_type: type[S], *args, **kwargs): + self._object_type = object_type + super().__init__(*args, **kwargs) + + def parse(self, s: str) -> S: + return self._object_type.gokart_deserialize(s) + + def serialize(self, x: S) -> str: + return x.gokart_serialize() diff --git a/test/test_serializable_parameter.py b/test/test_serializable_parameter.py new file mode 100644 index 00000000..57e56e2e --- /dev/null +++ b/test/test_serializable_parameter.py @@ -0,0 +1,83 @@ +import json +import tempfile +from dataclasses import asdict, dataclass + +import luigi +import pytest +from luigi.cmdline_parser import CmdlineParser +from mypy import api + +from gokart import SerializableParameter, TaskOnKart +from test.config import PYPROJECT_TOML + + +@dataclass(frozen=True) +class Config: + foo: int + bar: str + + def gokart_serialize(self) -> str: + # dict is ordered in Python 3.7+ + return json.dumps(asdict(self)) + + @classmethod + def gokart_deserialize(cls, s: str) -> 'Config': + return cls(**json.loads(s)) + + +class SerializableParameterWithOutDefault(TaskOnKart): + task_namespace = __name__ + config: Config = SerializableParameter(object_type=Config) + + def run(self): + self.dump(self.config) + + +class SerializableParameterWithDefault(TaskOnKart): + task_namespace = __name__ + config: Config = SerializableParameter(object_type=Config, default=Config(foo=1, bar='bar')) + + def run(self): + self.dump(self.config) + + +class TestSerializableParameter: + def test_default(self): + with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithDefault']) as cp: + assert cp.get_task_obj().config == Config(foo=1, bar='bar') + + def test_parse_param(self): + with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp: + assert cp.get_task_obj().config == Config(foo=100, bar='val') + + def test_missing_parameter(self): + with pytest.raises(luigi.parameter.MissingParameterException): + with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault']) as cp: + cp.get_task_obj() + + def test_value_error(self): + with pytest.raises(ValueError): + with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']) as cp: + cp.get_task_obj() + + def test_expected_one_argument_error(self): + with pytest.raises(SystemExit): + with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config']) as cp: + cp.get_task_obj() + + def test_mypy(self): + """check invalid object cannot used for SerializableParameter""" + + test_code = """ +import gokart + +class InvalidClass: + ... + +gokart.SerializableParameter(object_type=InvalidClass) + """ + with tempfile.NamedTemporaryFile(suffix='.py') as test_file: + test_file.write(test_code.encode('utf-8')) + test_file.flush() + result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) + assert 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass" [type-var]' in result[0]