From 01d5622b88d16104d28de6908e6d93e65d392bd3 Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Mon, 5 Aug 2024 22:34:00 -0700
Subject: [PATCH] Process outputs (#911)

---
 python/langsmith/run_helpers.py              |  40 ++++++--
 python/pyproject.toml                        |   2 +-
 python/tests/integration_tests/test_runs.py  |   6 +-
 python/tests/unit_tests/test_run_helpers.py  | 107 ++++++++++++++++++++
 4 files changed, 140 insertions(+), 15 deletions(-)

diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py
index 90301c03d..55d9d1ad0 100644
--- a/python/langsmith/run_helpers.py
+++ b/python/langsmith/run_helpers.py
@@ -25,6 +25,7 @@
     Mapping,
     Optional,
     Protocol,
+    Sequence,
     Tuple,
     Type,
     TypedDict,
@@ -242,9 +243,10 @@ def traceable(
     metadata: Optional[Mapping[str, Any]] = None,
     tags: Optional[List[str]] = None,
     client: Optional[ls_client.Client] = None,
-    reduce_fn: Optional[Callable] = None,
+    reduce_fn: Optional[Callable[[Sequence], dict]] = None,
     project_name: Optional[str] = None,
     process_inputs: Optional[Callable[[dict], dict]] = None,
+    process_outputs: Optional[Callable[..., dict]] = None,
     _invocation_params_fn: Optional[Callable[[dict], dict]] = None,
 ) -> Callable[[Callable[P, R]], SupportsLangsmithExtra[P, R]]: ...
@@ -270,7 +272,11 @@ def traceable(
             called, and the run itself will be stuck in a pending state.
         project_name: The name of the project to log the run to.
             Defaults to None, which will use the default project.
-        process_inputs: A function to filter the inputs to the run. Defaults to None.
+        process_inputs: Custom serialization / processing function for inputs.
+            Defaults to None.
+        process_outputs: Custom serialization / processing function for outputs.
+            Defaults to None.
+

     Returns:
@@ -415,6 +421,18 @@ def manual_extra_function(x):
         process_inputs=kwargs.pop("process_inputs", None),
         invocation_params_fn=kwargs.pop("_invocation_params_fn", None),
     )
+    outputs_processor = kwargs.pop("process_outputs", None)
+
+    def _on_run_end(
+        container: _TraceableContainer,
+        outputs: Optional[Any] = None,
+        error: Optional[BaseException] = None,
+    ) -> None:
+        """Handle the end of run."""
+        if outputs and outputs_processor is not None:
+            outputs = outputs_processor(outputs)
+        _container_end(container, outputs=outputs, error=error)
+
     if kwargs:
         warnings.warn(
             f"The following keyword arguments are not recognized and will be ignored: "
@@ -463,11 +481,11 @@ async def async_wrapper(
         except BaseException as e:
             # shield from cancellation, given we're catching all exceptions
             await asyncio.shield(
-                aitertools.aio_to_thread(_container_end, run_container, error=e)
+                aitertools.aio_to_thread(_on_run_end, run_container, error=e)
             )
             raise e
         await aitertools.aio_to_thread(
-            _container_end, run_container, outputs=function_result
+            _on_run_end, run_container, outputs=function_result
         )
         return function_result
@@ -536,7 +554,7 @@ async def async_generator_wrapper(
                 pass
         except BaseException as e:
             await asyncio.shield(
-                aitertools.aio_to_thread(_container_end, run_container, error=e)
+                aitertools.aio_to_thread(_on_run_end, run_container, error=e)
             )
             raise e
         if results:
@@ -551,7 +569,7 @@ async def async_generator_wrapper(
         else:
             function_result = None
         await aitertools.aio_to_thread(
-            _container_end, run_container, outputs=function_result
+            _on_run_end, run_container, outputs=function_result
         )

     @functools.wraps(func)
@@ -578,9 +596,9 @@ def wrapper(
             kwargs.pop("config", None)
             function_result = run_container["context"].run(func, *args, **kwargs)
         except BaseException as e:
-            _container_end(run_container, error=e)
+            _on_run_end(run_container, error=e)
             raise e
-        _container_end(run_container, outputs=function_result)
+        _on_run_end(run_container, outputs=function_result)
         return function_result

     @functools.wraps(func)
@@ -630,7 +648,7 @@ def generator_wrapper(
                 pass

         except BaseException as e:
-            _container_end(run_container, error=e)
+            _on_run_end(run_container, error=e)
             raise e
         if results:
             if reduce_fn:
@@ -643,7 +661,7 @@ def generator_wrapper(
             function_result = results
         else:
             function_result = None
-        _container_end(run_container, outputs=function_result)
+        _on_run_end(run_container, outputs=function_result)

     if inspect.isasyncgenfunction(func):
         selected_wrapper: Callable = async_generator_wrapper
@@ -1131,7 +1149,7 @@ def _container_end(
     container: _TraceableContainer,
     outputs: Optional[Any] = None,
     error: Optional[BaseException] = None,
-):
+) -> None:
     """End the run."""
     run_tree = container.get("new_run")
     if run_tree is None:
diff --git a/python/pyproject.toml b/python/pyproject.toml
index ff387d75d..27d00542f 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langsmith"
-version = "0.1.97"
+version = "0.1.98"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 authors = ["LangChain <support@langchain.dev>"]
 license = "MIT"
diff --git a/python/tests/integration_tests/test_runs.py b/python/tests/integration_tests/test_runs.py
index fbf87ea92..c9b62661e 100644
--- a/python/tests/integration_tests/test_runs.py
+++ b/python/tests/integration_tests/test_runs.py
@@ -3,7 +3,7 @@
 import uuid
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncGenerator, Generator, Optional
+from typing import AsyncGenerator, Generator, Optional, Sequence

 import pytest  # type: ignore

@@ -330,7 +330,7 @@ def test_sync_generator_reduce_fn(langchain_client: Client):
     project_name = "__My Tracer Project - test_sync_generator_reduce_fn"
     run_meta = uuid.uuid4().hex

-    def reduce_fn(outputs: list) -> dict:
+    def reduce_fn(outputs: Sequence) -> dict:
         return {"my_output": " ".join(outputs)}

     @traceable(run_type="chain", reduce_fn=reduce_fn)
@@ -411,7 +411,7 @@ async def test_async_generator_reduce_fn(langchain_client: Client):
     project_name = "__My Tracer Project - test_async_generator_reduce_fn"
     run_meta = uuid.uuid4().hex

-    def reduce_fn(outputs: list) -> dict:
+    def reduce_fn(outputs: Sequence) -> dict:
         return {"my_output": " ".join(outputs)}

     @traceable(run_type="chain", reduce_fn=reduce_fn)
diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py
index 4c960ddf8..f749dc17a 100644
--- a/python/tests/unit_tests/test_run_helpers.py
+++ b/python/tests/unit_tests/test_run_helpers.py
@@ -1341,3 +1341,110 @@ async def test_trace_respects_env_var(env_var: bool, context: Optional[bool]):
         assert len(mock_calls) >= 1
     else:
         assert not mock_calls
+
+
+async def test_process_inputs_outputs():
+    mock_client = _get_mock_client()
+    in_s = "what's life's meaning"
+
+    def process_inputs(inputs: dict) -> dict:
+        assert inputs == {"val": in_s, "ooblek": "nada"}
+        inputs["val2"] = "this is mutated"
+        return {"serialized_in": "what's the meaning of life?"}
+
+    def process_outputs(outputs: int) -> dict:
+        assert outputs == 42
+        return {"serialized_out": 24}
+
+    @traceable(process_inputs=process_inputs, process_outputs=process_outputs)
+    def my_function(val: str, **kwargs: Any) -> int:
+        assert not kwargs.get("val2")
+        return 42
+
+    with tracing_context(enabled=True):
+        my_function(
+            in_s,
+            ooblek="nada",
+            langsmith_extra={"client": mock_client},
+        )
+
+    def _check_client(client: Client) -> None:
+        mock_calls = _get_calls(client)
+        assert len(mock_calls) == 1
+        call = mock_calls[0]
+        assert call.args[0] == "POST"
+        assert call.args[1].startswith("https://api.smith.langchain.com")
+        body = json.loads(call.kwargs["data"])
+        assert body["post"]
+        assert body["post"][0]["inputs"] == {
+            "serialized_in": "what's the meaning of life?"
+        }
+        assert body["post"][0]["outputs"] == {"serialized_out": 24}
+
+    _check_client(mock_client)
+
+    @traceable(process_inputs=process_inputs, process_outputs=process_outputs)
+    async def amy_function(val: str, **kwargs: Any) -> int:
+        assert not kwargs.get("val2")
+        return 42
+
+    mock_client = _get_mock_client()
+    with tracing_context(enabled=True):
+        await amy_function(
+            in_s,
+            ooblek="nada",
+            langsmith_extra={"client": mock_client},
+        )
+
+    _check_client(mock_client)
+
+    # Do generator
+
+    def reducer(outputs: list) -> dict:
+        return {"reduced": outputs[0]}
+
+    def process_reduced_outputs(outputs: dict) -> dict:
+        assert outputs == {"reduced": 42}
+        return {"serialized_out": 24}
+
+    @traceable(
+        process_inputs=process_inputs,
+        process_outputs=process_reduced_outputs,
+        reduce_fn=reducer,
+    )
+    def my_gen(val: str, **kwargs: Any) -> Generator[int, None, None]:
+        assert not kwargs.get("val2")
+        yield 42
+
+    mock_client = _get_mock_client()
+    with tracing_context(enabled=True):
+        result = list(
+            my_gen(
+                in_s,
+                ooblek="nada",
+                langsmith_extra={"client": mock_client},
+            )
+        )
+        assert result == [42]
+
+    _check_client(mock_client)
+
+    @traceable(
+        process_inputs=process_inputs,
+        process_outputs=process_reduced_outputs,
+        reduce_fn=reducer,
+    )
+    async def amy_gen(val: str, **kwargs: Any) -> AsyncGenerator[int, None]:
+        assert not kwargs.get("val2")
+        yield 42
+
+    mock_client = _get_mock_client()
+    with tracing_context(enabled=True):
+        result = [
+            i
+            async for i in amy_gen(
+                in_s, ooblek="nada", langsmith_extra={"client": mock_client}
+            )
+        ]
+        assert result == [42]
+    _check_client(mock_client)
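
Usage sketch (illustrative, not part of the patch): process_outputs mirrors
process_inputs. It receives the traced function's raw return value and must
hand back a dict, which is what gets logged as the run's outputs. Note the
guard in _on_run_end: the processor only runs when outputs is truthy, so a
falsy return value (e.g. None) is recorded unprocessed. The scrub_* helper
names below are hypothetical, written against the docstring and unit test
above.

    from langsmith import traceable

    def scrub_inputs(inputs: dict) -> dict:
        # Receives the call's inputs as a dict; the returned dict is what
        # LangSmith stores as the run's inputs.
        return {"prompt_chars": len(inputs.get("prompt", ""))}

    def scrub_outputs(output: str) -> dict:
        # Receives the raw return value (any type); must return a dict.
        return {"answer_chars": len(output)}

    @traceable(process_inputs=scrub_inputs, process_outputs=scrub_outputs)
    def answer(prompt: str) -> str:
        return "forty-two"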
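For generator functions the ordering matters: reduce_fn first collapses the
yielded chunks into a single value, and process_outputs then serializes that
reduced value, as the reducer / process_reduced_outputs pair in the unit test
exercises. A minimal sketch, again with hypothetical helper names:

    from langsmith import traceable

    def join_chunks(chunks: list) -> dict:
        # reduce_fn: collapses the streamed chunks into one output value.
        return {"joined": "".join(chunks)}

    def serialize_joined(outputs: dict) -> dict:
        # process_outputs: sees the reduced dict, not the raw chunk list.
        return {"preview": outputs["joined"][:100]}

    @traceable(reduce_fn=join_chunks, process_outputs=serialize_joined)
    def stream_tokens(prompt: str):
        for tok in ("for", "ty-", "two"):
            yield tok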