Skip to content

Commit

Permalink
feature: support camelcase args (#50)
Browse files Browse the repository at this point in the history
* feature: add `to_lower_snake_case` to utils

* feature: support camelCase task args

* fix: outdated test
  • Loading branch information
ClemDoum authored Nov 14, 2024
1 parent 1f85c6d commit f2bf7da
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 20 deletions.
7 changes: 7 additions & 0 deletions icij-common/icij_common/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def to_lower_camel(field: str) -> str:
)


def to_lower_snake_case(s: str):
snake = "".join("_" + c.lower() if c.isupper() else c for c in s)
if snake.startswith("_"):
snake = snake[1:]
return snake


_FIELD_ARGS = ["include", "exclude", "update"]

_SCHEMAS = dict()
Expand Down
14 changes: 0 additions & 14 deletions icij-common/icij_common/tests/test_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,6 @@ async def test_es_client_should_search_with_pagination_size():
)


async def test_es_client_should_raise_when_size_is_provided():
# Given
pagination = 666
es_client = ESClient(pagination=pagination)
size = 100
body = None
index = "test-datashare-project"

# When/Then
expected_msg = "ESClient run searches using the pagination_size"
with pytest.raises(ValueError, match=expected_msg):
await es_client.search(body=body, index=index, size=size)


class _MockFailingClient(MockedESClient, metaclass=abc.ABCMeta):
def __init__(
self,
Expand Down
25 changes: 21 additions & 4 deletions icij-worker/icij_worker/tests/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from icij_worker.exceptions import TaskAlreadyCancelled, WorkerTimeoutError
from icij_worker.objects import ErrorEvent, ProgressEvent
from icij_worker.utils.tests import MockManager, MockWorker
from icij_worker.worker.worker import add_missing_args
from icij_worker.worker.worker import add_missing_args, task_wrapper


@pytest.fixture(
Expand Down Expand Up @@ -192,7 +192,7 @@ async def _has_state(state: TaskState) -> bool:
assert saved_result == expected_result


async def test_task_wrapper_should_recover_from_recoverable_error(
async def test_worker_should_recover_from_recoverable_error(
mock_failing_worker: MockWorker,
):
# Given
Expand Down Expand Up @@ -309,7 +309,7 @@ async def _has_state(state: TaskState) -> bool:
assert events == expected_events


async def test_task_wrapper_should_handle_fatal_error(mock_failing_worker: MockWorker):
async def test_worker_should_handle_fatal_error(mock_failing_worker: MockWorker):
# Given
worker = mock_failing_worker
task_manager = MockManager(worker.app, worker.db_path)
Expand Down Expand Up @@ -385,7 +385,7 @@ async def _has_state(state: TaskState) -> bool:
assert error_event == expected_error_event


async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWorker):
async def test_worker_should_handle_unregistered_task(mock_worker: MockWorker):
# Given
worker = mock_worker
task_manager = MockManager(worker.app, worker.db_path)
Expand Down Expand Up @@ -702,3 +702,20 @@ async def _assert_has_state(state: TaskState) -> bool:
error = errors[0].error
assert error.name == WorkerTimeoutError.__name__
t.cancel()


@pytest.mark.parametrize(
"mock_worker", [{"app": "test_async_app"}], indirect=["mock_worker"]
)
async def test_task_wrapper_should_handle_camel_case_args(mock_worker: MockWorker):
# Given
worker = mock_worker
args = {"snakeCaseArg": "Imma snake"}
task = Task.create(task_id="task_id", task_name="case_test_task", args=args)

# When
worker._current = task # pylint: disable=protected-access
task = await task_wrapper(worker, task)

# Then
assert task.state == TaskState.DONE
4 changes: 4 additions & 0 deletions icij-worker/icij_worker/utils/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ async def sleep_for_short(
async def often_retriable() -> str:
pass

@APP.task
def case_test_task(snake_case_arg: str):
return snake_case_arg

@pytest.fixture(scope="session")
def test_async_app() -> AsyncApp:
return AsyncApp.load(f"{__name__}.APP")
Expand Down
5 changes: 3 additions & 2 deletions icij-worker/icij_worker/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from typing_extensions import Self

from icij_common.pydantic_utils import safe_copy
from icij_common.pydantic_utils import safe_copy, to_lower_snake_case
from icij_worker import AsyncApp, ResultEvent, Task, TaskError, TaskState
from icij_worker.app import RegisteredTask, supports_progress
from icij_worker.event_publisher.event_publisher import EventPublisher
Expand Down Expand Up @@ -585,7 +585,8 @@ async def task_wrapper(worker: Worker, task: Task) -> Task:
raise TaskAlreadyCancelled(task_id=task.id)
# Parse task to retrieve recoverable errors and max retries
task_fn, recoverable_errors = worker.parse_task(task)
task_inputs = add_missing_args(task_fn, task.args)
task_args = {to_lower_snake_case(k): v for k, v in task.args.items()}
task_inputs = add_missing_args(task_fn, task_args)
# Retry task until success, fatal error or max retry exceeded
return await _retry_task(worker, task, task_fn, task_inputs, recoverable_errors)

Expand Down

0 comments on commit f2bf7da

Please sign in to comment.