Skip to content

Commit

Permalink
fix: DatashareTaskClient auth
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Dec 9, 2024
1 parent 096497d commit 01b06ef
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 30 deletions.
28 changes: 14 additions & 14 deletions icij-worker/icij_worker/ds_task_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ async def create_task(
*,
id_: Optional[str] = None,
group: Optional[str] = None,
) -> Task:
) -> str:
if id_ is None:
id_ = _generate_task_id(name)
task = Task.create(task_id=id_, task_name=name, args=args)
task = jsonable_encoder(task, exclude=_TASK_UNSUPPORTED, exclude_unset=True)
task.pop("createdAt")
url = f"/api/task/{id_}"
# TODO: we shouldn't have to write stuf like org.icij.datashare.asynctasks.Group
# here
data = {"task": task, "group": _make_java_group(group)}
async with self._put(url, json=data) as res:
task = await res.json()
task = Task(**task)
return task
if group is not None:
if not isinstance(group, str):
raise TypeError(f"expected group to be a string found {group}")
url += f"?group={group}"
async with self._put(url, json=task) as res:
task_res = await res.json()
return task_res["taskId"]

async def get_task(self, id_: str) -> Task:
url = f"/api/task/{id_}"
Expand All @@ -62,7 +62,7 @@ async def get_tasks(self) -> list[Task]:
async def get_task_state(self, id_: str) -> TaskState:
return (await self.get_task(id_)).state

async def get_task_result(self, id_: str) -> object:
async def get_task_result(self, id_: str) -> Any:
# TODO: we probably want to use /api/task/:id/results instead but it's
# restricted, we might need an API key or some auth
url = f"/api/task/{id_}"
Expand All @@ -84,10 +84,10 @@ def _generate_task_id(task_name: str) -> str:
return f"{task_name}-{uuid.uuid4()}"


def _ds_to_icij_worker_task(task: dict) -> dict:
task.pop("result", None)
return task
_JAVA_TASK_ATTRIBUTES = ["result", "error"]


def _make_java_group(group: str) -> dict:
return {"@type": "org.icij.datashare.asynctasks.Group", "id": group}
def _ds_to_icij_worker_task(task: dict) -> dict:
for k in _JAVA_TASK_ATTRIBUTES:
task.pop(k, None)
return task
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,43 @@
from icij_worker.ds_task_client import DatashareTaskClient


async def test_task_client_create_task(monkeypatch):
async def test_ds_task_client_create_task(monkeypatch):
# Given
datashare_url = "http://some-url"
task_name = "hello"
task_id = f"{task_name}-{uuid.uuid4()}"
args = {"greeted": "world"}
group = "PYTHON"

@asynccontextmanager
async def _put_and_assert(_, url: StrOrURL, *, data: Any = None, **kwargs: Any):
assert url == f"/api/task/{task_id}"
assert url == f"/api/task/{task_id}?group={group}"
expected_task = {
"@type": "Task",
"id": task_id,
"state": "CREATED",
"name": "hello",
"args": {"greeted": "world"},
}
expected_data = {"task": expected_task, "group": "PYTHON"}
expected_data = expected_task
assert data is None
json_data = kwargs.pop("json")
assert not kwargs
assert json_data == expected_data
expected_task["createdAt"] = datetime.now()
mocked_res = AsyncMock()
mocked_res.json.return_value = expected_task
mocked_res.json.return_value = {"taskId": task_id}
yield mocked_res

monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._put", _put_and_assert)

task_client = DatashareTaskClient(datashare_url)
async with task_client:
# When
task = await task_client.create_task(
task_name, args, id_=task_id, group="PYTHON"
)
assert isinstance(task, Task)
t_id = await task_client.create_task(task_name, args, id_=task_id, group=group)
assert t_id == task_id


async def test_task_client_get_task(monkeypatch):
async def test_ds_task_client_get_task(monkeypatch):
# Given
datashare_url = "http://some-url"
task_name = "hello"
Expand Down Expand Up @@ -82,7 +80,7 @@ async def _get_and_assert(
assert isinstance(task, Task)


async def test_task_client_get_task_state(monkeypatch):
async def test_ds_task_client_get_task_state(monkeypatch):
# Given
datashare_url = "http://some-url"
task_name = "hello"
Expand Down Expand Up @@ -118,7 +116,7 @@ async def _get_and_assert(
assert res == TaskState.DONE


async def test_task_client_get_task_result(monkeypatch):
async def test_ds_task_client_get_task_result(monkeypatch):
# Given
datashare_url = "http://some-url"
task_name = "hello"
Expand Down
23 changes: 19 additions & 4 deletions icij-worker/icij_worker/utils/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any, AsyncContextManager, Optional

from aiohttp import BasicAuth, ClientResponse, ClientResponseError, ClientSession
Expand All @@ -11,10 +12,18 @@


class AiohttpClient(AsyncContextManager):
def __init__(self, base_url: str, auth: Optional[BasicAuth] = None):
def __init__(
self,
base_url: str,
auth: Optional[BasicAuth] = None,
headers: dict | None = None,
):
self._base_url = base_url
self._auth = auth
self._session: Optional[ClientSession] = None
if headers is None:
headers = dict()
self._headers = headers

async def __aenter__(self):
self._session = ClientSession(self._base_url, auth=self._auth)
Expand All @@ -26,21 +35,27 @@ async def __aexit__(self, exc_type, exc_value, traceback):

@asynccontextmanager
async def _put(self, url: StrOrURL, *, data: Any = None, **kwargs: Any):
async with self._session.put(url, data=data, **kwargs) as res:
headers = deepcopy(self._headers)
headers.update(kwargs.pop("headers", dict()))
async with self._session.put(url, data=data, headers=headers, **kwargs) as res:
_raise_for_status(res)
yield res

@asynccontextmanager
async def _get(self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any):
headers = deepcopy(self._headers)
headers.update(kwargs.pop("headers", dict()))
async with self._session.get(
url, allow_redirects=allow_redirects, **kwargs
url, headers=headers, allow_redirects=allow_redirects, **kwargs
) as res:
_raise_for_status(res)
yield res

@asynccontextmanager
async def _delete(self, url: StrOrURL, **kwargs: Unpack[_RequestOptions]):
async with self._session.delete(url, **kwargs) as res:
headers = deepcopy(self._headers)
headers.update(kwargs.pop("headers", dict()))
async with self._session.delete(url, headers=headers, **kwargs) as res:
_raise_for_status(res)
yield res

Expand Down

0 comments on commit 01b06ef

Please sign in to comment.