Skip to content

Commit

Permalink
feat: add serializable parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Hironori Yamamoto committed Nov 27, 2024
1 parent 5138294 commit 4d5e42e
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gokart.build import WorkerSchedulerFactory, build # noqa:F401
from gokart.info import make_tree_info, tree_info # noqa:F401
from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, SerializableObjectParameter, TaskInstanceParameter # noqa:F401
from gokart.run import run # noqa:F401
from gokart.task import TaskOnKart # noqa:F401
from gokart.testing import test_run # noqa:F401
Expand Down
26 changes: 26 additions & 0 deletions gokart/parameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import bz2
import json
from logging import getLogger
from typing import Generic, Protocol, TypeVar

import luigi
from luigi import task_register
Expand Down Expand Up @@ -87,3 +88,28 @@ def __init__(self, *args, **kwargs):

def _parser_kwargs(self, *args, **kwargs): # type: ignore
return luigi.Parameter._parser_kwargs(*args, *kwargs)


T = TypeVar('T')


class Serializable(Protocol):
def serialize(self) -> str: ...

@classmethod
def deserialize(cls: type[T], s: str) -> T: ...


S = TypeVar('S', bound=Serializable)


class SerializableObjectParameter(luigi.Parameter, Generic[S]):
def __init__(self, object_type: type[S], *args, **kwargs):
self._object_type = object_type
luigi.Parameter.__init__(self, *args, **kwargs)

def parse(self, s: str) -> S:
return self._object_type.deserialize(s)

def serialize(self, x: S) -> str:
return x.serialize()
89 changes: 89 additions & 0 deletions test/test_serializable_object_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
import tempfile
from dataclasses import asdict, dataclass

import luigi
import pytest
from luigi.cmdline_parser import CmdlineParser
from mypy import api

from gokart import SerializableObjectParameter, TaskOnKart, build
from test.config import PYPROJECT_TOML


@dataclass(frozen=True)
class Config:
foo: int
bar: str

def serialize(self) -> str:
# dict is ordered in Python 3.7+
return json.dumps(asdict(self))

@classmethod
def deserialize(cls, s: str) -> 'Config':
return cls(**json.loads(s))


class WithOutDefault(TaskOnKart):
task_namespace = __name__
config: Config = SerializableObjectParameter(object_type=Config)

def run(self):
self.dump(self.config)


class WithDefault(TaskOnKart):
task_namespace = __name__
config: Config = SerializableObjectParameter(object_type=Config, default=Config(foo=1, bar='bar'))

def run(self):
self.dump(self.config)


class TestSerializableObjectParameter:
def test_default(self):
with CmdlineParser.global_instance([f'{__name__}.WithDefault']) as cp:
assert cp.get_task_obj().config == Config(foo=1, bar='bar')

def test_parse_param(self):
with CmdlineParser.global_instance([f'{__name__}.WithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp:
assert cp.get_task_obj().config == Config(foo=100, bar='val')

def test_missing_parameter(self):
with pytest.raises(luigi.parameter.MissingParameterException):
with CmdlineParser.global_instance([f'{__name__}.WithOutDefault']) as cp:
cp.get_task_obj()

def test_value_error(self):
with pytest.raises(ValueError):
with CmdlineParser.global_instance([f'{__name__}.WithOutDefault', '--config', 'Foo']) as cp:
cp.get_task_obj()

def test_expected_one_argment_error(self):
with pytest.raises(SystemExit):
with CmdlineParser.global_instance([f'{__name__}.WithOutDefault', '--config']) as cp:
cp.get_task_obj()

def test_build(self):
WithOutDefault.clear_instance_cache()
config = Config(foo=100, bar='val')
actual = build(WithOutDefault(config=config))
assert actual == config

def test_mypy(self):
"""check invalid object cannot used for SerializableObjectParameter"""

test_code = """
import gokart
class InvalidClass:
...
gokart.SerializableObjectParameter(object_type=InvalidClass)
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
assert 'Value of type variable "S" of "SerializableObjectParameter" cannot be "InvalidClass" [type-var]' in result[0]

0 comments on commit 4d5e42e

Please sign in to comment.