Skip to content

Commit

Permalink
feat: add serializable parameter (#411)
Browse files Browse the repository at this point in the history
Co-authored-by: Hironori Yamamoto <hironori-yamamoto@m3.com>
  • Loading branch information
hiro-o918 and Hironori Yamamoto authored Nov 29, 2024
1 parent 5138294 commit 15737f6
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gokart.build import WorkerSchedulerFactory, build # noqa:F401
from gokart.info import make_tree_info, tree_info # noqa:F401
from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, SerializableParameter, TaskInstanceParameter # noqa:F401
from gokart.run import run # noqa:F401
from gokart.task import TaskOnKart # noqa:F401
from gokart.testing import test_run # noqa:F401
Expand Down
32 changes: 32 additions & 0 deletions gokart/parameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import bz2
import json
from logging import getLogger
from typing import Generic, Protocol, TypeVar

import luigi
from luigi import task_register
Expand Down Expand Up @@ -87,3 +88,34 @@ def __init__(self, *args, **kwargs):

def _parser_kwargs(self, *args, **kwargs):  # type: ignore
    """Return the parser keyword arguments produced by ``luigi.Parameter``.

    All positional and keyword arguments are forwarded unchanged to
    ``luigi.Parameter._parser_kwargs``; ``self`` is not passed on.
    NOTE(review): presumably this bypasses a parent-class override of
    ``_parser_kwargs`` -- confirm against the enclosing class definition.
    """
    # Keyword arguments must be expanded with ``**``. The previous ``*kwargs``
    # passed only the dictionary's keys as extra positional arguments, so a
    # call such as ``_parser_kwargs(name, task_name='T')`` forwarded the
    # literal string 'task_name' instead of the value 'T'.
    return luigi.Parameter._parser_kwargs(*args, **kwargs)


# Ties the return type of ``Serializable.gokart_deserialize`` to the class it is called on.
T = TypeVar('T')


class Serializable(Protocol):
    """Structural interface for objects that can be converted to and from a string.

    Any class that defines both methods below satisfies this protocol; no
    explicit inheritance is required.
    """

    def gokart_serialize(self) -> str:
        """Implement this method to serialize the object as a parameter.

        You can omit some fields from the serialized result if you want
        changes to those fields to be ignored.
        """
        ...

    @classmethod
    def gokart_deserialize(cls: type[T], s: str) -> T:
        """Implement this method to deserialize the object from a string.

        Args:
            s: The string to deserialize.

        Returns:
            An instance of the class the method is called on.
        """
        ...


# Any type that structurally satisfies the ``Serializable`` protocol.
S = TypeVar('S', bound=Serializable)


class SerializableParameter(luigi.Parameter, Generic[S]):
    """Luigi parameter whose value is an object implementing ``Serializable``.

    Conversion between the object and its string form is delegated to the
    object's own ``gokart_deserialize`` / ``gokart_serialize`` methods.
    """

    def __init__(self, object_type: type[S], *args, **kwargs):
        """Remember the value class and initialise the underlying luigi parameter.

        Args:
            object_type: Class whose ``gokart_deserialize`` is used by ``parse``.
            *args: Forwarded to ``luigi.Parameter.__init__``.
            **kwargs: Forwarded to ``luigi.Parameter.__init__``.
        """
        self._object_type = object_type
        super().__init__(*args, **kwargs)

    def parse(self, s: str) -> S:
        """Build the parameter value from its string form ``s``."""
        deserialize = self._object_type.gokart_deserialize
        return deserialize(s)

    def serialize(self, x: S) -> str:
        """Return the string form of the parameter value ``x``."""
        serialized = x.gokart_serialize()
        return serialized
83 changes: 83 additions & 0 deletions test/test_serializable_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
import tempfile
from dataclasses import asdict, dataclass

import luigi
import pytest
from luigi.cmdline_parser import CmdlineParser
from mypy import api

from gokart import SerializableParameter, TaskOnKart
from test.config import PYPROJECT_TOML


@dataclass(frozen=True)
class Config:
    """Immutable sample value object that round-trips through JSON."""

    foo: int
    bar: str

    def gokart_serialize(self) -> str:
        """Return the fields encoded as a JSON object string."""
        # dicts are ordered in Python 3.7+, so keys follow field declaration order
        field_values = asdict(self)
        return json.dumps(field_values)

    @classmethod
    def gokart_deserialize(cls, s: str) -> 'Config':
        """Create an instance from the JSON object string ``s``."""
        field_values = json.loads(s)
        return cls(**field_values)


class SerializableParameterWithOutDefault(TaskOnKart):
    """Task fixture whose ``config`` parameter is declared without a default value."""

    # Namespace the task by module name; the tests address it as f'{__name__}.<ClassName>'.
    task_namespace = __name__
    config: Config = SerializableParameter(object_type=Config)

    def run(self):
        # Pass the parsed parameter value straight to ``dump``.
        self.dump(self.config)


class SerializableParameterWithDefault(TaskOnKart):
    """Task fixture whose ``config`` parameter defaults to ``Config(foo=1, bar='bar')``."""

    # Namespace the task by module name; the tests address it as f'{__name__}.<ClassName>'.
    task_namespace = __name__
    config: Config = SerializableParameter(object_type=Config, default=Config(foo=1, bar='bar'))

    def run(self):
        # Pass the parameter value straight to ``dump``.
        self.dump(self.config)


class TestSerializableParameter:
    """Command-line parsing and static-typing tests for ``SerializableParameter``."""

    def test_default(self):
        """The declared default is used when ``--config`` is not given."""
        argv = [f'{__name__}.SerializableParameterWithDefault']
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().config == Config(foo=1, bar='bar')

    def test_parse_param(self):
        """A JSON string given on the command line is parsed into a ``Config``."""
        argv = [f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']
        with CmdlineParser.global_instance(argv) as cp:
            assert cp.get_task_obj().config == Config(foo=100, bar='val')

    def test_missing_parameter(self):
        """Omitting a parameter that has no default raises ``MissingParameterException``."""
        argv = [f'{__name__}.SerializableParameterWithOutDefault']
        with pytest.raises(luigi.parameter.MissingParameterException), CmdlineParser.global_instance(argv) as cp:
            cp.get_task_obj()

    def test_value_error(self):
        """A value that is not valid JSON raises ``ValueError``."""
        argv = [f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']
        with pytest.raises(ValueError), CmdlineParser.global_instance(argv) as cp:
            cp.get_task_obj()

    def test_expected_one_argument_error(self):
        """Passing ``--config`` without a value raises ``SystemExit``."""
        argv = [f'{__name__}.SerializableParameterWithOutDefault', '--config']
        with pytest.raises(SystemExit), CmdlineParser.global_instance(argv) as cp:
            cp.get_task_obj()

    def test_mypy(self):
        """Check that an invalid object type cannot be used for SerializableParameter."""

        test_code = """
import gokart
class InvalidClass:
    ...
gokart.SerializableParameter(object_type=InvalidClass)
"""
        expected_error = 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass" [type-var]'
        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
            test_file.write(test_code.encode('utf-8'))
            # Flush so mypy sees the full snippet when it opens the file by name.
            test_file.flush()
            mypy_args = ['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]
            result = api.run(mypy_args)
            # The first element of the result is checked for the expected type-var error.
            assert expected_error in result[0]

0 comments on commit 15737f6

Please sign in to comment.