Add revision identifier to run_on_dataset (langchain-ai#16167)
Allow specifying revision identifier for better project versioning
samnoyes authored Jan 18, 2024
1 parent 5d8c147 commit 7d44472
Showing 1 changed file with 54 additions and 10 deletions.
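
For context, a minimal usage sketch of what this change enables (hedged: it assumes the public `run_on_dataset` entry point from `langchain.smith` and a LangSmith `Client`; the dataset name, evaluator, and model factory below are illustrative and not part of this commit):

```python
# Illustrative sketch only: the dataset, evaluator, and model are placeholders.
# The new piece added by this commit is the `revision_id` argument.
from langsmith import Client
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain_openai import ChatOpenAI

results = run_on_dataset(
    client=Client(),
    dataset_name="my-eval-dataset",
    llm_or_chain_factory=lambda: ChatOpenAI(temperature=0),
    evaluation=RunEvalConfig(evaluators=["qa"]),
    # Tag this test run with a revision identifier so runs from different
    # versions of your system can be compared in LangSmith.
    revision_id="2024-01-18-prompt-v2",
)
```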
libs/langchain/langchain/smith/evaluation/runner_utils.py (64 changes: 54 additions & 10 deletions)
@@ -657,6 +657,7 @@ async def _arun_llm(
     tags: Optional[List[str]] = None,
     callbacks: Callbacks = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """Asynchronously run the language model.
@@ -682,7 +683,9 @@
         ):
             return await llm.ainvoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -695,12 +698,18 @@
     try:
         prompt = _get_prompt(inputs)
         llm_output: Union[str, BaseMessage] = await llm.ainvoke(
-            prompt, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            prompt,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     except InputFormatError:
         messages = _get_messages(inputs)
         llm_output = await llm.ainvoke(
-            messages, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            messages,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     return llm_output

@@ -712,6 +721,7 @@ async def _arun_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[dict, str]:
     """Run a chain asynchronously on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -723,10 +733,15 @@
     ):
         val = next(iter(inputs_.values()))
         output = await chain.ainvoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = await chain.ainvoke(inputs_, config=runnable_config)
     return output

@@ -762,6 +777,7 @@ async def _arun_llm_or_chain(
                 tags=config["tags"],
                 callbacks=config["callbacks"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         else:
             chain = llm_or_chain_factory()
@@ -771,6 +787,7 @@
                 tags=config["tags"],
                 callbacks=config["callbacks"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         result = output
     except Exception as e:
@@ -793,6 +810,7 @@ def _run_llm(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[str, BaseMessage]:
     """
     Run the language model on the example.
@@ -819,7 +837,9 @@
         ):
             llm_output: Union[str, BaseMessage] = llm.invoke(
                 prompt_or_messages,
-                config=RunnableConfig(callbacks=callbacks, tags=tags or []),
+                config=RunnableConfig(
+                    callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+                ),
             )
         else:
             raise InputFormatError(
@@ -831,12 +851,16 @@
     try:
         llm_prompts = _get_prompt(inputs)
         llm_output = llm.invoke(
-            llm_prompts, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            llm_prompts,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
         )
     except InputFormatError:
         llm_messages = _get_messages(inputs)
         llm_output = llm.invoke(
-            llm_messages, config=RunnableConfig(callbacks=callbacks)
+            llm_messages,
+            config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
         )
     return llm_output

@@ -848,6 +872,7 @@ def _run_chain(
     *,
     tags: Optional[List[str]] = None,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
 ) -> Union[Dict, str]:
     """Run a chain on inputs."""
     inputs_ = inputs if input_mapper is None else input_mapper(inputs)
@@ -859,10 +884,15 @@
     ):
         val = next(iter(inputs_.values()))
         output = chain.invoke(
-            val, config=RunnableConfig(callbacks=callbacks, tags=tags or [])
+            val,
+            config=RunnableConfig(
+                callbacks=callbacks, tags=tags or [], metadata=metadata or {}
+            ),
        )
     else:
-        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
+        runnable_config = RunnableConfig(
+            tags=tags or [], callbacks=callbacks, metadata=metadata or {}
+        )
         output = chain.invoke(inputs_, config=runnable_config)
     return output

@@ -899,6 +929,7 @@ def _run_llm_or_chain(
                 config["callbacks"],
                 tags=config["tags"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         else:
             chain = llm_or_chain_factory()
@@ -908,6 +939,7 @@
                 config["callbacks"],
                 tags=config["tags"],
                 input_mapper=input_mapper,
+                metadata=config.get("metadata"),
             )
         result = output
     except Exception as e:
@@ -1083,8 +1115,13 @@ def prepare(
         input_mapper: Optional[Callable[[Dict], Any]] = None,
         concurrency_level: int = 5,
         project_metadata: Optional[Dict[str, Any]] = None,
+        revision_id: Optional[str] = None,
     ) -> _DatasetRunContainer:
         project_name = project_name or name_generation.random_name()
+        if revision_id:
+            if not project_metadata:
+                project_metadata = {}
+            project_metadata.update({"revision_id": revision_id})
         wrapped_model, project, dataset, examples = _prepare_eval_run(
             client,
             dataset_name,
@@ -1121,6 +1158,7 @@
                 ],
                 tags=tags,
                 max_concurrency=concurrency_level,
+                metadata={"revision_id": revision_id} if revision_id else {},
             )
             for example in examples
         ]
@@ -1183,6 +1221,7 @@ async def arun_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1208,6 +1247,7 @@
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     batch_results = await runnable_utils.gather_with_concurrency(
         container.configs[0].get("max_concurrency"),
@@ -1235,6 +1275,7 @@ def run_on_dataset(
     project_metadata: Optional[Dict[str, Any]] = None,
     verbose: bool = False,
     tags: Optional[List[str]] = None,
+    revision_id: Optional[str] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     input_mapper = kwargs.pop("input_mapper", None)
@@ -1260,6 +1301,7 @@
         input_mapper,
         concurrency_level,
         project_metadata=project_metadata,
+        revision_id=revision_id,
     )
     if concurrency_level == 0:
         batch_results = [
@@ -1309,6 +1351,8 @@ def run_on_dataset(
         log feedback and run traces.
     verbose: Whether to print progress.
     tags: Tags to add to each run in the project.
+    revision_id: Optional revision identifier to assign this test run to
+        track the performance of different versions of your system.
 
 Returns:
     A dictionary containing the run's project name and the resulting model outputs.

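The mechanism is visible in the diff above: `revision_id` is folded into `project_metadata` in `prepare()` and also attached to every run's `RunnableConfig` metadata. A minimal standalone sketch of that per-run pattern (the helper name is hypothetical; `chain` stands in for any Runnable):

```python
# Sketch of the per-run pattern the diff threads through each helper:
# the revision identifier rides along as run metadata via RunnableConfig.
from typing import Any, Dict, Optional

from langchain_core.runnables import Runnable, RunnableConfig


def invoke_with_revision(
    chain: Runnable, inputs: Dict[str, Any], revision_id: Optional[str] = None
) -> Any:
    # Hypothetical helper: mirrors how runner_utils passes metadata through
    # RunnableConfig so the revision shows up on each traced run.
    config = RunnableConfig(
        tags=["evaluation"],
        metadata={"revision_id": revision_id} if revision_id else {},
    )
    return chain.invoke(inputs, config=config)
```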