Skip to content

Commit

Permalink
CVar Propagation in evals (#877)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Jul 17, 2024
1 parent 158b999 commit 4147a60
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 69 deletions.
9 changes: 3 additions & 6 deletions python/langsmith/_expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def test_output_semantically_close():
from __future__ import annotations

import atexit
import concurrent.futures
import inspect
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -91,15 +90,13 @@ def __init__(
client: Optional[ls_client.Client],
key: str,
value: Any,
_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None,
_executor: Optional[ls_utils.ContextThreadPoolExecutor] = None,
run_id: Optional[str] = None,
):
self._client = client
self.key = key
self.value = value
self._executor = _executor or concurrent.futures.ThreadPoolExecutor(
max_workers=3
)
self._executor = _executor or ls_utils.ContextThreadPoolExecutor(max_workers=3)
rt = rh.get_current_run_tree()
self._run_id = rt.trace_id if rt else run_id

Expand Down Expand Up @@ -255,7 +252,7 @@ class _Expect:

def __init__(self, *, client: Optional[ls_client.Client] = None):
self._client = client
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
self.executor = ls_utils.ContextThreadPoolExecutor(max_workers=3)
atexit.register(self.executor.shutdown, wait=True)

def embedding_distance(
Expand Down
13 changes: 12 additions & 1 deletion python/langsmith/_internal/_aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,13 @@ async def process_item(item):

async def process_generator():
tasks = []
accepts_context = asyncio_accepts_context()
async for item in generator:
task = asyncio.create_task(process_item(item))
if accepts_context:
context = contextvars.copy_context()
task = asyncio.create_task(process_item(item), context=context)
else:
task = asyncio.create_task(process_item(item))
tasks.append(task)
if n is not None and len(tasks) >= n:
done, pending = await asyncio.wait(
Expand Down Expand Up @@ -319,3 +324,9 @@ async def aio_to_thread(func, /, *args, **kwargs):
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)


@functools.lru_cache(maxsize=1)
def asyncio_accepts_context():
    """Return whether ``asyncio.create_task`` accepts a ``context`` argument.

    Used to decide if a copied ``contextvars`` context can be passed
    explicitly when scheduling tasks (the ``context`` keyword was added to
    ``asyncio.create_task`` in Python 3.11).  The result cannot change
    within a single interpreter run, so it is cached with ``maxsize=1``.
    """
    # `accepts_context` is a signature-inspection helper defined elsewhere in
    # this module — presumably it checks for a `context` parameter; verify.
    return accepts_context(asyncio.create_task)
3 changes: 1 addition & 2 deletions python/langsmith/_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import atexit
import concurrent.futures
import datetime
import functools
import inspect
Expand Down Expand Up @@ -392,7 +391,7 @@ def __init__(
self._experiment = experiment
self._dataset = dataset
self._version: Optional[datetime.datetime] = None
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self._executor = ls_utils.ContextThreadPoolExecutor(max_workers=1)
atexit.register(_end_tests, self)

@property
Expand Down
5 changes: 3 additions & 2 deletions python/langsmith/beta/_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

import collections
import concurrent.futures
import datetime
import itertools
import uuid
Expand Down Expand Up @@ -218,6 +217,8 @@ def compute_test_metrics(
Returns:
None: This function does not return any value.
"""
from langsmith import ContextThreadPoolExecutor

evaluators_: List[ls_eval.RunEvaluator] = []
for func in evaluators:
if isinstance(func, ls_eval.RunEvaluator):
Expand All @@ -230,7 +231,7 @@ def compute_test_metrics(
)
client = client or Client()
traces = _load_nested_traces(project_name, client)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
with ContextThreadPoolExecutor(max_workers=max_concurrency) as executor:
results = executor.map(
client.evaluate_run, *zip(*_outer_product(traces, evaluators_))
)
Expand Down
56 changes: 31 additions & 25 deletions python/langsmith/evaluation/_arunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,12 @@ async def _arun_evaluators(
**{"experiment": self.experiment_name},
}
with rh.tracing_context(
**{**current_context, "project_name": "evaluators", "metadata": metadata}
**{
**current_context,
"project_name": "evaluators",
"metadata": metadata,
"enabled": True,
}
):
run = current_results["run"]
example = current_results["example"]
Expand Down Expand Up @@ -676,11 +681,11 @@ async def _aapply_summary_evaluators(
**current_context,
"project_name": "evaluators",
"metadata": metadata,
"enabled": True,
}
):
for evaluator in summary_evaluators:
try:
# TODO: Support async evaluators
summary_eval_result = evaluator(runs, examples)
flattened_results = self.client._select_eval_results(
summary_eval_result,
Expand Down Expand Up @@ -808,30 +813,31 @@ def _get_run(r: run_trees.RunTree) -> None:
nonlocal run
run = r

try:
await fn(
example.inputs,
langsmith_extra=rh.LangSmithExtra(
reference_example_id=example.id,
on_end=_get_run,
project_name=experiment_name,
metadata={
**metadata,
"example_version": (
example.modified_at.isoformat()
if example.modified_at
else example.created_at.isoformat()
),
},
client=client,
),
with rh.tracing_context(enabled=True):
try:
await fn(
example.inputs,
langsmith_extra=rh.LangSmithExtra(
reference_example_id=example.id,
on_end=_get_run,
project_name=experiment_name,
metadata={
**metadata,
"example_version": (
example.modified_at.isoformat()
if example.modified_at
else example.created_at.isoformat()
),
},
client=client,
),
)
except Exception as e:
logger.error(f"Error running target function: {e}")
return _ForwardResults(
run=cast(schemas.Run, run),
example=example,
)
except Exception as e:
logger.error(f"Error running target function: {e}")
return _ForwardResults(
run=cast(schemas.Run, run),
example=example,
)


def _ensure_async_traceable(
Expand Down
69 changes: 40 additions & 29 deletions python/langsmith/evaluation/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,9 @@ def evaluate_and_submit_feedback(
return result

tqdm = _load_tqdm()
with cf.ThreadPoolExecutor(max_workers=max_concurrency or 1) as executor:
with ls_utils.ContextThreadPoolExecutor(
max_workers=max_concurrency or 1
) as executor:
futures = []
for example_id, runs_list in tqdm(runs_dict.items()):
results[example_id] = {
Expand Down Expand Up @@ -1207,7 +1209,7 @@ def _predict(
)

else:
with cf.ThreadPoolExecutor(max_concurrency) as executor:
with ls_utils.ContextThreadPoolExecutor(max_concurrency) as executor:
futures = [
executor.submit(
_forward,
Expand Down Expand Up @@ -1239,7 +1241,12 @@ def _run_evaluators(
},
}
with rh.tracing_context(
**{**current_context, "project_name": "evaluators", "metadata": metadata}
**{
**current_context,
"project_name": "evaluators",
"metadata": metadata,
"enabled": True,
}
):
run = current_results["run"]
example = current_results["example"]
Expand Down Expand Up @@ -1280,10 +1287,13 @@ def _score(
(e.g. from a previous prediction step)
"""
if max_concurrency == 0:
context = copy_context()
for current_results in self.get_results():
yield self._run_evaluators(evaluators, current_results)
yield context.run(self._run_evaluators, evaluators, current_results)
else:
with cf.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
with ls_utils.ContextThreadPoolExecutor(
max_workers=max_concurrency
) as executor:
futures = []
for current_results in self.get_results():
futures.append(
Expand All @@ -1305,7 +1315,7 @@ def _apply_summary_evaluators(
runs.append(run)
examples.append(example)
aggregate_feedback = []
with cf.ThreadPoolExecutor() as executor:
with ls_utils.ContextThreadPoolExecutor() as executor:
project_id = self._get_experiment().id
current_context = rh.get_tracing_context()
metadata = {
Expand Down Expand Up @@ -1447,30 +1457,31 @@ def _get_run(r: run_trees.RunTree) -> None:
nonlocal run
run = r

try:
fn(
example.inputs,
langsmith_extra=rh.LangSmithExtra(
reference_example_id=example.id,
on_end=_get_run,
project_name=experiment_name,
metadata={
**metadata,
"example_version": (
example.modified_at.isoformat()
if example.modified_at
else example.created_at.isoformat()
),
},
client=client,
),
with rh.tracing_context(enabled=True):
try:
fn(
example.inputs,
langsmith_extra=rh.LangSmithExtra(
reference_example_id=example.id,
on_end=_get_run,
project_name=experiment_name,
metadata={
**metadata,
"example_version": (
example.modified_at.isoformat()
if example.modified_at
else example.created_at.isoformat()
),
},
client=client,
),
)
except Exception as e:
logger.error(f"Error running target function: {e}")
return _ForwardResults(
run=cast(schemas.Run, run),
example=example,
)
except Exception as e:
logger.error(f"Error running target function: {e}")
return _ForwardResults(
run=cast(schemas.Run, run),
example=example,
)


def _resolve_data(
Expand Down
4 changes: 2 additions & 2 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ async def async_wrapper(
)

try:
accepts_context = aitertools.accepts_context(asyncio.create_task)
accepts_context = aitertools.asyncio_accepts_context()
if func_accepts_parent_run:
kwargs["run_tree"] = run_container["new_run"]
if not func_accepts_config:
Expand Down Expand Up @@ -492,7 +492,7 @@ async def async_generator_wrapper(
kwargs.pop("config", None)
async_gen_result = func(*args, **kwargs)
# Can't iterate through if it's a coroutine
accepts_context = aitertools.accepts_context(asyncio.create_task)
accepts_context = aitertools.asyncio_accepts_context()
if inspect.iscoroutine(async_gen_result):
if accepts_context:
async_gen_result = await asyncio.create_task(
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.88"
version = "0.1.89"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <support@langchain.dev>"]
license = "MIT"
Expand Down
1 change: 0 additions & 1 deletion python/tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def predict(inputs: dict) -> dict:
},
num_repetitions=3,
)
results.wait()
assert len(results) == 30
examples = client.list_examples(dataset_name=dataset_name)
for example in examples:
Expand Down

0 comments on commit 4147a60

Please sign in to comment.