From 546a36f4ae283d43df4ca7b0590eda5ba12206b0 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Tue, 30 Jul 2024 05:09:53 -0700 Subject: [PATCH] [Python] Fix nesting in async trace context manager (#895) We were previously setting the context vars in a separate context and then letting it be gc'd. Fixes https://github.com/langchain-ai/langsmith-sdk/issues/892 --- python/langsmith/_internal/_aiter.py | 6 ++-- python/langsmith/run_helpers.py | 14 ++++++++-- python/pyproject.toml | 2 +- python/tests/unit_tests/test_run_helpers.py | 31 +++++++++++++++++++-- 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/python/langsmith/_internal/_aiter.py b/python/langsmith/_internal/_aiter.py index a2f0701a1..7ae217f68 100644 --- a/python/langsmith/_internal/_aiter.py +++ b/python/langsmith/_internal/_aiter.py @@ -310,7 +310,9 @@ def accepts_context(callable: Callable[..., Any]) -> bool: # Ported from Python 3.9+ to support Python 3.8 -async def aio_to_thread(func, /, *args, **kwargs): +async def aio_to_thread( + func, /, *args, __ctx: Optional[contextvars.Context] = None, **kwargs +): """Asynchronously run function *func* in a separate thread. Any *args and **kwargs supplied for this function are directly passed @@ -321,7 +323,7 @@ async def aio_to_thread(func, /, *args, **kwargs): Return a coroutine that can be awaited to get the eventual result of *func*. """ loop = asyncio.get_running_loop() - ctx = contextvars.copy_context() + ctx = __ctx or contextvars.copy_context() func_call = functools.partial(ctx.run, func, *args, **kwargs) return await loop.run_in_executor(None, func_call) diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 1131400bd..41885796c 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -926,7 +926,11 @@ async def __aenter__(self) -> run_trees.RunTree: Returns: run_trees.RunTree: The newly created run. """ - return await aitertools.aio_to_thread(self._setup) + ctx = copy_context() + result = await aitertools.aio_to_thread(self._setup, __ctx=ctx) + # Set the context for the current thread + _set_tracing_context(get_tracing_context(ctx)) + return result async def __aexit__( self, @@ -941,14 +945,18 @@ async def __aexit__( exc_value: The exception instance that occurred, if any. traceback: The traceback object associated with the exception, if any. """ + ctx = copy_context() if exc_type is not None: await asyncio.shield( - aitertools.aio_to_thread(self._teardown, exc_type, exc_value, traceback) + aitertools.aio_to_thread( + self._teardown, exc_type, exc_value, traceback, __ctx=ctx + ) ) else: await aitertools.aio_to_thread( - self._teardown, exc_type, exc_value, traceback + self._teardown, exc_type, exc_value, traceback, __ctx=ctx ) + _set_tracing_context(get_tracing_context(ctx)) def _get_project_name(project_name: Optional[str]) -> Optional[str]: diff --git a/python/pyproject.toml b/python/pyproject.toml index f6f9fa609..dd9143861 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.93" +version = "0.1.94" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." authors = ["LangChain "] license = "MIT" diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index 4bbc182c9..d5be6c1dd 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -962,12 +962,25 @@ def _get_run(r: RunTree) -> None: async def test_traceable_to_atrace(): + @traceable + async def great_grandchild_fn(a: int, b: int) -> int: + return a + b + @traceable async def parent_fn(a: int, b: int) -> int: async with langsmith.trace( name="child_fn", inputs={"a": a, "b": b} ) as run_tree: - result = a + b + async with langsmith.trace( + "grandchild_fn", inputs={"a": a, "b": b, "c": "oh my"} + ) as run_tree_gc: + try: + async with langsmith.trace("expect_error", inputs={}): + raise ValueError("oh no") + except ValueError: + pass + result = await great_grandchild_fn(a, b) + run_tree_gc.end(outputs={"result": result}) run_tree.end(outputs={"result": result}) return result @@ -991,8 +1004,20 @@ def _get_run(r: RunTree) -> None: child_runs = run.child_runs assert child_runs assert len(child_runs) == 1 - assert child_runs[0].name == "child_fn" - assert child_runs[0].inputs == {"a": 1, "b": 2} + child = child_runs[0] + assert child.name == "child_fn" + assert child.inputs == {"a": 1, "b": 2} + assert len(child.child_runs) == 1 + grandchild = child.child_runs[0] + assert grandchild.name == "grandchild_fn" + assert grandchild.inputs == {"a": 1, "b": 2, "c": "oh my"} + assert len(grandchild.child_runs) == 2 + ggcerror = grandchild.child_runs[0] + assert ggcerror.name == "expect_error" + assert "oh no" in str(ggcerror.error) + ggc = grandchild.child_runs[1] + assert ggc.name == "great_grandchild_fn" + assert ggc.inputs == {"a": 1, "b": 2} def test_trace_to_traceable():