Skip to content

Commit

Permalink
Process outputs (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Aug 6, 2024
1 parent 0badaa0 commit 01d5622
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 15 deletions.
40 changes: 29 additions & 11 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypedDict,
Expand Down Expand Up @@ -242,9 +243,10 @@ def traceable(
metadata: Optional[Mapping[str, Any]] = None,
tags: Optional[List[str]] = None,
client: Optional[ls_client.Client] = None,
reduce_fn: Optional[Callable] = None,
reduce_fn: Optional[Callable[[Sequence], dict]] = None,
project_name: Optional[str] = None,
process_inputs: Optional[Callable[[dict], dict]] = None,
process_outputs: Optional[Callable[..., dict]] = None,
_invocation_params_fn: Optional[Callable[[dict], dict]] = None,
) -> Callable[[Callable[P, R]], SupportsLangsmithExtra[P, R]]: ...

Expand All @@ -270,7 +272,11 @@ def traceable(
called, and the run itself will be stuck in a pending state.
project_name: The name of the project to log the run to. Defaults to None,
which will use the default project.
process_inputs: A function to filter the inputs to the run. Defaults to None.
process_inputs: Custom serialization / processing function for inputs.
Defaults to None.
process_outputs: Custom serialization / processing function for outputs.
Defaults to None.
Returns:
Expand Down Expand Up @@ -415,6 +421,18 @@ def manual_extra_function(x):
process_inputs=kwargs.pop("process_inputs", None),
invocation_params_fn=kwargs.pop("_invocation_params_fn", None),
)
outputs_processor = kwargs.pop("process_outputs", None)

def _on_run_end(
    container: _TraceableContainer,
    outputs: Optional[Any] = None,
    error: Optional[BaseException] = None,
) -> None:
    """Finalize a traced run: post-process its outputs and close the container.

    Args:
        container: The run container created when tracing began.
        outputs: The raw function result, if the call succeeded.
        error: The exception raised by the traced function, if it failed.
    """
    # Fix: the original guard `if outputs and outputs_processor is not None`
    # silently skipped the processor whenever the result was falsy (0, "",
    # [], {}), even though those are legitimate outputs. Only skip when
    # there is genuinely nothing to process.
    if outputs_processor is not None and outputs is not None:
        outputs = outputs_processor(outputs)
    _container_end(container, outputs=outputs, error=error)

if kwargs:
warnings.warn(
f"The following keyword arguments are not recognized and will be ignored: "
Expand Down Expand Up @@ -463,11 +481,11 @@ async def async_wrapper(
except BaseException as e:
# shield from cancellation, given we're catching all exceptions
await asyncio.shield(
aitertools.aio_to_thread(_container_end, run_container, error=e)
aitertools.aio_to_thread(_on_run_end, run_container, error=e)
)
raise e
await aitertools.aio_to_thread(
_container_end, run_container, outputs=function_result
_on_run_end, run_container, outputs=function_result
)
return function_result

Expand Down Expand Up @@ -536,7 +554,7 @@ async def async_generator_wrapper(
pass
except BaseException as e:
await asyncio.shield(
aitertools.aio_to_thread(_container_end, run_container, error=e)
aitertools.aio_to_thread(_on_run_end, run_container, error=e)
)
raise e
if results:
Expand All @@ -551,7 +569,7 @@ async def async_generator_wrapper(
else:
function_result = None
await aitertools.aio_to_thread(
_container_end, run_container, outputs=function_result
_on_run_end, run_container, outputs=function_result
)

@functools.wraps(func)
Expand All @@ -578,9 +596,9 @@ def wrapper(
kwargs.pop("config", None)
function_result = run_container["context"].run(func, *args, **kwargs)
except BaseException as e:
_container_end(run_container, error=e)
_on_run_end(run_container, error=e)
raise e
_container_end(run_container, outputs=function_result)
_on_run_end(run_container, outputs=function_result)
return function_result

@functools.wraps(func)
Expand Down Expand Up @@ -630,7 +648,7 @@ def generator_wrapper(
pass

except BaseException as e:
_container_end(run_container, error=e)
_on_run_end(run_container, error=e)
raise e
if results:
if reduce_fn:
Expand All @@ -643,7 +661,7 @@ def generator_wrapper(
function_result = results
else:
function_result = None
_container_end(run_container, outputs=function_result)
_on_run_end(run_container, outputs=function_result)

if inspect.isasyncgenfunction(func):
selected_wrapper: Callable = async_generator_wrapper
Expand Down Expand Up @@ -1131,7 +1149,7 @@ def _container_end(
container: _TraceableContainer,
outputs: Optional[Any] = None,
error: Optional[BaseException] = None,
):
) -> None:
"""End the run."""
run_tree = container.get("new_run")
if run_tree is None:
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.97"
version = "0.1.98"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <support@langchain.dev>"]
license = "MIT"
Expand Down
6 changes: 3 additions & 3 deletions python/tests/integration_tests/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncGenerator, Generator, Optional
from typing import AsyncGenerator, Generator, Optional, Sequence

import pytest # type: ignore

Expand Down Expand Up @@ -330,7 +330,7 @@ def test_sync_generator_reduce_fn(langchain_client: Client):
project_name = "__My Tracer Project - test_sync_generator_reduce_fn"
run_meta = uuid.uuid4().hex

def reduce_fn(outputs: list) -> dict:
def reduce_fn(outputs: Sequence) -> dict:
    """Merge streamed chunks into a single space-separated output mapping."""
    merged = " ".join(outputs)
    return {"my_output": merged}

@traceable(run_type="chain", reduce_fn=reduce_fn)
Expand Down Expand Up @@ -411,7 +411,7 @@ async def test_async_generator_reduce_fn(langchain_client: Client):
project_name = "__My Tracer Project - test_async_generator_reduce_fn"
run_meta = uuid.uuid4().hex

def reduce_fn(outputs: list) -> dict:
def reduce_fn(outputs: Sequence) -> dict:
    """Collapse the streamed chunks into one dict keyed by ``my_output``."""
    return {"my_output": " ".join(outputs)}

@traceable(run_type="chain", reduce_fn=reduce_fn)
Expand Down
107 changes: 107 additions & 0 deletions python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,3 +1341,110 @@ async def test_trace_respects_env_var(env_var: bool, context: Optional[bool]):
assert len(mock_calls) >= 1
else:
assert not mock_calls


async def test_process_inputs_outputs():
    """Verify process_inputs/process_outputs hooks across all four wrapper
    flavors of @traceable: sync, async, sync generator, async generator.

    Each case asserts that the POSTed run body contains the *serialized*
    inputs/outputs produced by the hooks, not the raw function values.
    """
    mock_client = _get_mock_client()
    in_s = "what's life's meaning"

    def process_inputs(inputs: dict) -> dict:
        # Hook receives the bound call arguments; mutations here must NOT
        # leak back into the traced function's kwargs (checked below).
        assert inputs == {"val": in_s, "ooblek": "nada"}
        inputs["val2"] = "this is mutated"
        return {"serialized_in": "what's the meaning of life?"}

    def process_outputs(outputs: int) -> dict:
        # Hook receives the raw return value and replaces the logged outputs.
        assert outputs == 42
        return {"serialized_out": 24}

    @traceable(process_inputs=process_inputs, process_outputs=process_outputs)
    def my_function(val: str, **kwargs: Any) -> int:
        # The mutation performed inside process_inputs must not be visible here.
        assert not kwargs.get("val2")
        return 42

    with tracing_context(enabled=True):
        my_function(
            in_s,
            ooblek="nada",
            langsmith_extra={"client": mock_client},
        )

    def _check_client(client: Client) -> None:
        # Inspect the single batched POST the mock client made and confirm
        # the logged run carries the hook-serialized payloads.
        mock_calls = _get_calls(client)
        assert len(mock_calls) == 1
        call = mock_calls[0]
        assert call.args[0] == "POST"
        assert call.args[1].startswith("https://api.smith.langchain.com")
        body = json.loads(call.kwargs["data"])
        assert body["post"]
        assert body["post"][0]["inputs"] == {
            "serialized_in": "what's the meaning of life?"
        }
        assert body["post"][0]["outputs"] == {"serialized_out": 24}

    _check_client(mock_client)

    @traceable(process_inputs=process_inputs, process_outputs=process_outputs)
    async def amy_function(val: str, **kwargs: Any) -> int:
        assert not kwargs.get("val2")
        return 42

    # Async flavor: fresh client so _check_client still sees exactly one call.
    mock_client = _get_mock_client()
    with tracing_context(enabled=True):
        await amy_function(
            in_s,
            ooblek="nada",
            langsmith_extra={"client": mock_client},
        )

    _check_client(mock_client)

    # Do generator

    def reducer(outputs: list) -> dict:
        # reduce_fn folds the yielded chunks into one value first...
        return {"reduced": outputs[0]}

    def process_reduced_outputs(outputs: dict) -> dict:
        # ...and process_outputs then runs on the *reduced* value.
        assert outputs == {"reduced": 42}
        return {"serialized_out": 24}

    @traceable(
        process_inputs=process_inputs,
        process_outputs=process_reduced_outputs,
        reduce_fn=reducer,
    )
    def my_gen(val: str, **kwargs: Any) -> Generator[int, None, None]:
        assert not kwargs.get("val2")
        yield 42

    mock_client = _get_mock_client()
    with tracing_context(enabled=True):
        result = list(
            my_gen(
                in_s,
                ooblek="nada",
                langsmith_extra={"client": mock_client},
            )
        )
    # The consumer still sees the raw yielded values, untouched by the hooks.
    assert result == [42]

    _check_client(mock_client)

    @traceable(
        process_inputs=process_inputs,
        process_outputs=process_reduced_outputs,
        reduce_fn=reducer,
    )
    async def amy_gen(val: str, **kwargs: Any) -> AsyncGenerator[int, None]:
        assert not kwargs.get("val2")
        yield 42

    mock_client = _get_mock_client()
    with tracing_context(enabled=True):
        result = [
            i
            async for i in amy_gen(
                in_s, ooblek="nada", langsmith_extra={"client": mock_client}
            )
        ]
    assert result == [42]
    _check_client(mock_client)

0 comments on commit 01d5622

Please sign in to comment.