From 0afc018ed5b4473b93ca0467b29fe64c908e0e87 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 19 Jul 2024 18:34:05 -0700 Subject: [PATCH] Async trace context manager (#887) Exposed via the same `trace` CM. Would resolve https://github.com/langchain-ai/langsmith-sdk/issues/882 --- python/langsmith/run_helpers.py | 357 +++++++++++++++----- python/pyproject.toml | 2 +- python/tests/unit_tests/test_run_helpers.py | 34 ++ 3 files changed, 301 insertions(+), 92 deletions(-) diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index 4afa8e69e..1131400bd 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -43,6 +43,8 @@ from langsmith.env import _runtime_env if TYPE_CHECKING: + from types import TracebackType + from langchain_core.runnables import Runnable LOGGER = logging.getLogger(__name__) @@ -685,6 +687,270 @@ def generator_wrapper( return decorator +class trace: + """Manage a langsmith run in context. + + This class can be used as both a synchronous and asynchronous context manager. + + Parameters: + ----------- + name : str + Name of the run + run_type : ls_client.RUN_TYPE_T, optional + Type of run (e.g., "chain", "llm", "tool"). Defaults to "chain". + inputs : Optional[Dict], optional + Initial input data for the run + project_name : Optional[str], optional + Associates the run with a specific project, overriding defaults + parent : Optional[Union[run_trees.RunTree, str, Mapping]], optional + Parent run, accepts RunTree, dotted order string, or tracing headers + tags : Optional[List[str]], optional + Categorization labels for the run + metadata : Optional[Mapping[str, Any]], optional + Arbitrary key-value pairs for run annotation + client : Optional[ls_client.Client], optional + LangSmith client for specifying a different tenant, + setting custom headers, or modifying API endpoint + run_id : Optional[ls_client.ID_TYPE], optional + Preset identifier for the run + reference_example_id : Optional[ls_client.ID_TYPE], optional + You typically won't set this. It associates this run with a dataset example. + This is only valid for root runs (not children) in an evaluation context. + exceptions_to_handle : Optional[Tuple[Type[BaseException], ...]], optional + Typically not set. Exception types to ignore in what is sent up to LangSmith + extra : Optional[Dict], optional + Typically not set. Use 'metadata' instead. Extra data to be sent to LangSmith. + + Examples: + --------- + Synchronous usage: + >>> with trace("My Operation", run_type="tool", tags=["important"]) as run: + ... result = "foo" # Do some_operation() + ... run.metadata["some-key"] = "some-value" + ... run.end(outputs={"result": result}) + + Asynchronous usage: + >>> async def main(): + ... async with trace("Async Operation", run_type="tool", tags=["async"]) as run: + ... result = "foo" # Can await some_async_operation() + ... run.metadata["some-key"] = "some-value" + ... # "end" just adds the outputs and sets error to None + ... # The actual patching of the run happens when the context exits + ... run.end(outputs={"result": result}) + >>> asyncio.run(main()) + + Allowing pytest.skip in a test: + >>> import sys + >>> import pytest + >>> with trace("OS-Specific Test", exceptions_to_handle=(pytest.skip.Exception,)): + ... if sys.platform == "win32": + ... pytest.skip("Not supported on Windows") + ... result = "foo" # e.g., do some unix_specific_operation() + """ + + def __init__( + self, + name: str, + run_type: ls_client.RUN_TYPE_T = "chain", + *, + inputs: Optional[Dict] = None, + extra: Optional[Dict] = None, + project_name: Optional[str] = None, + parent: Optional[Union[run_trees.RunTree, str, Mapping]] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Mapping[str, Any]] = None, + client: Optional[ls_client.Client] = None, + run_id: Optional[ls_client.ID_TYPE] = None, + reference_example_id: Optional[ls_client.ID_TYPE] = None, + exceptions_to_handle: Optional[Tuple[Type[BaseException], ...]] = None, + **kwargs: Any, + ): + """Initialize the trace context manager. + + Warns if unsupported kwargs are passed. + """ + if kwargs: + warnings.warn( + "The `trace` context manager no longer supports the following kwargs: " + f"{sorted(kwargs.keys())}.", + DeprecationWarning, + ) + self.name = name + self.run_type = run_type + self.inputs = inputs + self.extra = extra + self.project_name = project_name + self.parent = parent + # The run tree is deprecated. Keeping for backwards compat. + # Will fully merge within parent later. + self.run_tree = kwargs.get("run_tree") + self.tags = tags + self.metadata = metadata + self.client = client + self.run_id = run_id + self.reference_example_id = reference_example_id + self.exceptions_to_handle = exceptions_to_handle + self.new_run: Optional[run_trees.RunTree] = None + self.old_ctx: Optional[dict] = None + + def _setup(self) -> run_trees.RunTree: + """Set up the tracing context and create a new run. + + This method initializes the tracing context, merges tags and metadata, + creates a new run (either as a child of an existing run or as a new root run), + and sets up the necessary context variables. + + Returns: + run_trees.RunTree: The newly created run. + """ + self.old_ctx = get_tracing_context() + is_disabled = self.old_ctx.get("enabled", True) is False + outer_tags = _TAGS.get() + outer_metadata = _METADATA.get() + parent_run_ = _get_parent_run( + { + "parent": self.parent, + "run_tree": self.run_tree, + "client": self.client, + } + ) + + tags_ = sorted(set((self.tags or []) + (outer_tags or []))) + metadata = { + **(self.metadata or {}), + **(outer_metadata or {}), + "ls_method": "trace", + } + + extra_outer = self.extra or {} + extra_outer["metadata"] = metadata + + project_name_ = _get_project_name(self.project_name) + + if parent_run_ is not None and not is_disabled: + self.new_run = parent_run_.create_child( + name=self.name, + run_id=self.run_id, + run_type=self.run_type, + extra=extra_outer, + inputs=self.inputs, + tags=tags_, + ) + else: + self.new_run = run_trees.RunTree( + name=self.name, + id=ls_client._ensure_uuid(self.run_id), + reference_example_id=ls_client._ensure_uuid( + self.reference_example_id, accept_null=True + ), + run_type=self.run_type, + extra=extra_outer, + project_name=project_name_ or "default", + inputs=self.inputs or {}, + tags=tags_, + client=self.client, # type: ignore[arg-type] + ) + + if not is_disabled: + self.new_run.post() + _TAGS.set(tags_) + _METADATA.set(metadata) + _PARENT_RUN_TREE.set(self.new_run) + _PROJECT_NAME.set(project_name_) + + return self.new_run + + def _teardown( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Clean up the tracing context and finalize the run. + + This method handles exceptions, ends the run if necessary, + patches the run if it's not disabled, and resets the tracing context. + + Args: + exc_type: The type of the exception that occurred, if any. + exc_value: The exception instance that occurred, if any. + traceback: The traceback object associated with the exception, if any. + """ + if self.new_run is None: + warnings.warn("Tracing context was not set up properly.", RuntimeWarning) + return + if exc_type is not None: + if self.exceptions_to_handle and issubclass( + exc_type, self.exceptions_to_handle + ): + tb = None + else: + tb = utils._format_exc() + tb = f"{exc_type.__name__}: {exc_value}\n\n{tb}" + self.new_run.end(error=tb) + if self.old_ctx is not None: + is_disabled = self.old_ctx.get("enabled", True) is False + if not is_disabled: + self.new_run.patch() + + _set_tracing_context(self.old_ctx) + else: + warnings.warn("Tracing context was not set up properly.", RuntimeWarning) + + def __enter__(self) -> run_trees.RunTree: + """Enter the context manager synchronously. + + Returns: + run_trees.RunTree: The newly created run. + """ + return self._setup() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + """Exit the context manager synchronously. + + Args: + exc_type: The type of the exception that occurred, if any. + exc_value: The exception instance that occurred, if any. + traceback: The traceback object associated with the exception, if any. + """ + self._teardown(exc_type, exc_value, traceback) + + async def __aenter__(self) -> run_trees.RunTree: + """Enter the context manager asynchronously. + + Returns: + run_trees.RunTree: The newly created run. + """ + return await aitertools.aio_to_thread(self._setup) + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + """Exit the context manager asynchronously. + + Args: + exc_type: The type of the exception that occurred, if any. + exc_value: The exception instance that occurred, if any. + traceback: The traceback object associated with the exception, if any. + """ + if exc_type is not None: + await asyncio.shield( + aitertools.aio_to_thread(self._teardown, exc_type, exc_value, traceback) + ) + else: + await aitertools.aio_to_thread( + self._teardown, exc_type, exc_value, traceback + ) + + def _get_project_name(project_name: Optional[str]) -> Optional[str]: prt = _PARENT_RUN_TREE.get() return ( @@ -698,97 +964,6 @@ def _get_project_name(project_name: Optional[str]) -> Optional[str]: ) -@contextlib.contextmanager -def trace( - name: str, - run_type: ls_client.RUN_TYPE_T = "chain", - *, - inputs: Optional[Dict] = None, - extra: Optional[Dict] = None, - project_name: Optional[str] = None, - parent: Optional[Union[run_trees.RunTree, str, Mapping]] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, - client: Optional[ls_client.Client] = None, - run_id: Optional[ls_client.ID_TYPE] = None, - reference_example_id: Optional[ls_client.ID_TYPE] = None, - exceptions_to_handle: Optional[Tuple[Type[BaseException], ...]] = None, - **kwargs: Any, -) -> Generator[run_trees.RunTree, None, None]: - """Context manager for creating a run tree.""" - if kwargs: - # In case someone was passing an executor before. - warnings.warn( - "The `trace` context manager no longer supports the following kwargs: " - f"{sorted(kwargs.keys())}.", - DeprecationWarning, - ) - old_ctx = get_tracing_context() - is_disabled = old_ctx.get("enabled", True) is False - outer_tags = _TAGS.get() - outer_metadata = _METADATA.get() - parent_run_ = _get_parent_run( - {"parent": parent, "run_tree": kwargs.get("run_tree"), "client": client} - ) - - # Merge context variables - tags_ = sorted(set((tags or []) + (outer_tags or []))) - metadata = {**(metadata or {}), **(outer_metadata or {}), "ls_method": "trace"} - - extra_outer = extra or {} - extra_outer["metadata"] = metadata - - project_name_ = _get_project_name(project_name) - # If it's disabled, we break the tree - if parent_run_ is not None and not is_disabled: - new_run = parent_run_.create_child( - name=name, - run_id=run_id, - run_type=run_type, - extra=extra_outer, - inputs=inputs, - tags=tags_, - ) - else: - new_run = run_trees.RunTree( - name=name, - id=ls_client._ensure_uuid(run_id), - reference_example_id=ls_client._ensure_uuid( - reference_example_id, accept_null=True - ), - run_type=run_type, - extra=extra_outer, - project_name=project_name_, # type: ignore[arg-type] - inputs=inputs or {}, - tags=tags_, - client=client, # type: ignore[arg-type] - ) - if not is_disabled: - new_run.post() - _TAGS.set(tags_) - _METADATA.set(metadata) - _PARENT_RUN_TREE.set(new_run) - _PROJECT_NAME.set(project_name_) - - try: - yield new_run - except (Exception, KeyboardInterrupt, BaseException) as e: - if exceptions_to_handle and isinstance(e, exceptions_to_handle): - tb = None - else: - tb = utils._format_exc() - tb = f"{e.__class__.__name__}: {e}\n\n{tb}" - new_run.end(error=tb) - if not is_disabled: - new_run.patch() - raise e - finally: - # Reset the old context - _set_tracing_context(old_ctx) - if not is_disabled: - new_run.patch() - - def as_runnable(traceable_fn: Callable) -> Runnable: """Convert a function wrapped by the LangSmith @traceable decorator to a Runnable. diff --git a/python/pyproject.toml b/python/pyproject.toml index 3ed59d26f..f6f9fa609 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langsmith" -version = "0.1.92" +version = "0.1.93" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." authors = ["LangChain "] license = "MIT" diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index 434c10bca..4bbc182c9 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -961,6 +961,40 @@ def _get_run(r: RunTree) -> None: assert child_runs[0].inputs == {"a": 1, "b": 2} +async def test_traceable_to_atrace(): + @traceable + async def parent_fn(a: int, b: int) -> int: + async with langsmith.trace( + name="child_fn", inputs={"a": a, "b": b} + ) as run_tree: + result = a + b + run_tree.end(outputs={"result": result}) + return result + + run: Optional[RunTree] = None # type: ignore + + def _get_run(r: RunTree) -> None: + nonlocal run + run = r + + with tracing_context(enabled=True): + result = await parent_fn( + 1, 2, langsmith_extra={"on_end": _get_run, "client": _get_mock_client()} + ) + + assert result == 3 + assert run is not None + run = cast(RunTree, run) + assert run.name == "parent_fn" + assert run.outputs == {"output": 3} + assert run.inputs == {"a": 1, "b": 2} + child_runs = run.child_runs + assert child_runs + assert len(child_runs) == 1 + assert child_runs[0].name == "child_fn" + assert child_runs[0].inputs == {"a": 1, "b": 2} + + def test_trace_to_traceable(): @traceable def child_fn(a: int, b: int) -> int: