Skip to content

Commit

Permalink
🌿 manually modify start job from local file to match server expectations
Browse files Browse the repository at this point in the history
  • Loading branch information
fern-api[bot] committed Sep 30, 2024
1 parent 493fc8d commit 8de2de7
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 7 deletions.
131 changes: 130 additions & 1 deletion src/hume/expression_measurement/batch/client_with_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import aiofiles
import typing
import json as jsonlib
from json.decoder import JSONDecodeError

from ...core.request_options import RequestOptions
from ...core.jsonable_encoder import jsonable_encoder
from ... import core

from .types.inference_base_request import InferenceBaseRequest
from ...core.pydantic_utilities import parse_obj_as
from .types.job_id import JobId
from .client import AsyncBatchClient, BatchClient
from ...core.api_error import ApiError

class BatchClientWithUtils(BatchClient):
def get_and_write_job_artifacts(
Expand Down Expand Up @@ -47,6 +55,66 @@ def get_and_write_job_artifacts(
for chunk in self.get_job_artifacts(id=id, request_options=request_options):
f.write(chunk)

def start_inference_job_from_local_file(
self,
*,
file: typing.List[core.File],
json: typing.Optional[InferenceBaseRequest] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> str:
"""
Start a new batch inference job.
Parameters
----------
file : typing.List[core.File]
See core.File for more documentation
json : typing.Optional[InferenceBaseRequest]
The inference job configuration.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
str
Examples
--------
from hume import HumeClient
client = HumeClient(
api_key="YOUR_API_KEY",
)
client.expression_measurement.batch.start_inference_job_from_local_file()
"""
_response = self._client_wrapper.httpx_client.request(
"v0/batch/jobs",
method="POST",
files={
"file": file,
"json": jsonlib.dumps(jsonable_encoder(json)).encode("utf-8"),
},
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_parsed_response = typing.cast(
JobId,
parse_obj_as(
type_=JobId, # type: ignore
object_=_response.json(),
),
)
return _parsed_response.job_id
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)


class AsyncBatchClientWithUtils(AsyncBatchClient):
async def get_and_write_job_artifacts(
self,
Expand Down Expand Up @@ -87,4 +155,65 @@ async def get_and_write_job_artifacts(
"""
async with aiofiles.open(file_name, mode='wb') as f:
async for chunk in self.get_job_artifacts(id=id, request_options=request_options):
await f.write(chunk)
await f.write(chunk)

async def start_inference_job_from_local_file(
self,
*,
file: typing.List[core.File],
json: typing.Optional[InferenceBaseRequest] = None,
request_options: typing.Optional[RequestOptions] = None,
) -> str:
"""
Start a new batch inference job.
Parameters
----------
file : typing.List[core.File]
See core.File for more documentation
json : typing.Optional[InferenceBaseRequest]
The inference job configuration.
request_options : typing.Optional[RequestOptions]
Request-specific configuration.
Returns
-------
str
Examples
--------
from hume import HumeClient
client = HumeClient(
api_key="YOUR_API_KEY",
)
client.expression_measurement.batch.start_inference_job_from_local_file()
"""
_response = await self._client_wrapper.httpx_client.request(
"v0/batch/jobs",
method="POST",
data={
"json": jsonlib.dumps(jsonable_encoder(json)).encode("utf-8"),
},
files={
"file": file,
},
request_options=request_options,
)
try:
if 200 <= _response.status_code < 300:
_parsed_response = typing.cast(
JobId,
parse_obj_as(
type_=JobId, # type: ignore
object_=_response.json(),
),
)
return _parsed_response.job_id
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)
23 changes: 17 additions & 6 deletions tests/custom/test_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import pytest
import aiofiles

from hume.client import AsyncHumeClient
from hume.client import AsyncHumeClient, HumeClient
from hume.expression_measurement.batch.types.face import Face
from hume.expression_measurement.batch.types.inference_base_request import InferenceBaseRequest
from hume.expression_measurement.batch.types.models import Models

# Get started with writing tests with pytest at https://docs.pytest.org
@pytest.mark.skip(reason="Unimplemented")
def test_client() -> None:
assert True == True

@pytest.mark.skip(reason="CI does not have authentication.")
async def test_write_job_artifacts() -> None:
Expand All @@ -20,4 +19,16 @@ async def test_get_job_predictions() -> None:
client = AsyncHumeClient(api_key="MY_API_KEY")
await client.expression_measurement.batch.get_job_predictions(id="my-job-id", request_options={
"max_retries": 3,
})
})

# @pytest.mark.skip(reason="CI does not have authentication.")
async def test_start_inference_job_from_local_file() -> None:
client = HumeClient(api_key="MY_API_KEY")
client.expression_measurement.batch.start_inference_job_from_local_file(
file=[],
json=InferenceBaseRequest(
models=Models(
face=Face()
)
)
)

0 comments on commit 8de2de7

Please sign in to comment.