Skip to content

Commit

Permalink
[Python] Fix nesting in async trace context manager (#895)
Browse files Browse the repository at this point in the history
We were previously setting the context vars in a separate context and
then letting it be gc'd.

Fixes #892
  • Loading branch information
hinthornw authored Jul 30, 2024
1 parent 6ba2489 commit 546a36f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 9 deletions.
6 changes: 4 additions & 2 deletions python/langsmith/_internal/_aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,9 @@ def accepts_context(callable: Callable[..., Any]) -> bool:


# Ported from Python 3.9+ to support Python 3.8
async def aio_to_thread(func, /, *args, **kwargs):
async def aio_to_thread(
func, /, *args, __ctx: Optional[contextvars.Context] = None, **kwargs
):
"""Asynchronously run function *func* in a separate thread.
Any *args and **kwargs supplied for this function are directly passed
Expand All @@ -321,7 +323,7 @@ async def aio_to_thread(func, /, *args, **kwargs):
Return a coroutine that can be awaited to get the eventual result of *func*.
"""
loop = asyncio.get_running_loop()
ctx = contextvars.copy_context()
ctx = __ctx or contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)

Expand Down
14 changes: 11 additions & 3 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,11 @@ async def __aenter__(self) -> run_trees.RunTree:
Returns:
run_trees.RunTree: The newly created run.
"""
return await aitertools.aio_to_thread(self._setup)
ctx = copy_context()
result = await aitertools.aio_to_thread(self._setup, __ctx=ctx)
# Set the context for the current thread
_set_tracing_context(get_tracing_context(ctx))
return result

async def __aexit__(
self,
Expand All @@ -941,14 +945,18 @@ async def __aexit__(
exc_value: The exception instance that occurred, if any.
traceback: The traceback object associated with the exception, if any.
"""
ctx = copy_context()
if exc_type is not None:
await asyncio.shield(
aitertools.aio_to_thread(self._teardown, exc_type, exc_value, traceback)
aitertools.aio_to_thread(
self._teardown, exc_type, exc_value, traceback, __ctx=ctx
)
)
else:
await aitertools.aio_to_thread(
self._teardown, exc_type, exc_value, traceback
self._teardown, exc_type, exc_value, traceback, __ctx=ctx
)
_set_tracing_context(get_tracing_context(ctx))


def _get_project_name(project_name: Optional[str]) -> Optional[str]:
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.93"
version = "0.1.94"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <support@langchain.dev>"]
license = "MIT"
Expand Down
31 changes: 28 additions & 3 deletions python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,12 +962,25 @@ def _get_run(r: RunTree) -> None:


async def test_traceable_to_atrace():
@traceable
async def great_grandchild_fn(a: int, b: int) -> int:
return a + b

@traceable
async def parent_fn(a: int, b: int) -> int:
async with langsmith.trace(
name="child_fn", inputs={"a": a, "b": b}
) as run_tree:
result = a + b
async with langsmith.trace(
"grandchild_fn", inputs={"a": a, "b": b, "c": "oh my"}
) as run_tree_gc:
try:
async with langsmith.trace("expect_error", inputs={}):
raise ValueError("oh no")
except ValueError:
pass
result = await great_grandchild_fn(a, b)
run_tree_gc.end(outputs={"result": result})
run_tree.end(outputs={"result": result})
return result

Expand All @@ -991,8 +1004,20 @@ def _get_run(r: RunTree) -> None:
child_runs = run.child_runs
assert child_runs
assert len(child_runs) == 1
assert child_runs[0].name == "child_fn"
assert child_runs[0].inputs == {"a": 1, "b": 2}
child = child_runs[0]
assert child.name == "child_fn"
assert child.inputs == {"a": 1, "b": 2}
assert len(child.child_runs) == 1
grandchild = child.child_runs[0]
assert grandchild.name == "grandchild_fn"
assert grandchild.inputs == {"a": 1, "b": 2, "c": "oh my"}
assert len(grandchild.child_runs) == 2
ggcerror = grandchild.child_runs[0]
assert ggcerror.name == "expect_error"
assert "oh no" in str(ggcerror.error)
ggc = grandchild.child_runs[1]
assert ggc.name == "great_grandchild_fn"
assert ggc.inputs == {"a": 1, "b": 2}


def test_trace_to_traceable():
Expand Down

0 comments on commit 546a36f

Please sign in to comment.