Skip to content

Commit

Permalink
Add extraction task (#29)
Browse files Browse the repository at this point in the history
Add extraction task
  • Loading branch information
eyurtsev authored Nov 20, 2023
1 parent 5f2ce54 commit f79d797
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 36 deletions.
Empty file.
75 changes: 75 additions & 0 deletions langchain_benchmarks/extraction/email_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from enum import Enum
from typing import Optional, List

from langchain.smith import RunEvalConfig
from pydantic import BaseModel, Field

from langchain_benchmarks.schema import ExtractionTask


class ToneEnum(str, Enum):
"""The tone of the email."""

positive = "positive"
negative = "negative"


class Email(BaseModel):
"""Relevant information about an email."""

sender: Optional[str] = Field(None, description="The sender's name, if available")
sender_phone_number: Optional[str] = Field(
None, description="The sender's phone number, if available"
)
sender_address: Optional[str] = Field(
None, description="The sender's address, if available"
)
action_items: List[str] = Field(
..., description="A list of action items requested by the email"
)
topic: str = Field(
..., description="High level description of what the email is about"
)
tone: ToneEnum = Field(..., description="The tone of the email.")


def get_eval_config(eval_llm: BaseModel) -> RunEvalConfig:
"""Get the evaluation configuration for the email task."""
return RunEvalConfig(
evaluators=[
RunEvalConfig.LabeledScoreString(
criteria={
"accuracy": """
Score 1: The answer is incorrect and unrelated to the question or reference document.
Score 3: The answer is partially correct but has more than one omission or major errors.
Score 5: The answer is mostly correct but has more than one omission or major error.
Score 7: The answer is mostly correct but has at most one omission or major error.
Score 9: The answer is mostly correct with no omissions and only minor errors, and aligns with the reference document.
Score 10: The answer is correct, complete, and aligns with the reference document. Extra information is acceptable if it is sensible.
If the reference answer contains multiple alternatives, the predicted answer must only match one of the alternatives to be considered correct.
If the predicted answer contains additional helpful and accurate information that is not present in the reference answer, it should still be considered correct and not be penalized.
""" # noqa
},
llm=eval_llm,
normalize_by=10.0,
),
],
)


EmailTask = ExtractionTask(
id=4, # To be deprecated
name="Email Extraction",
dataset_id="https://smith.langchain.com/public/36bdfe7d-3cd1-4b36-b957-d12d95810a2b/d",
model=Email,
description="""\
A dataset of 42 real emails deduped from a spam folder, with semantic HTML tags removed, \
as well as a script for initial extraction and formatting of other emails from \
an arbitrary .mbox file like the one exported by Gmail.
Some additional cleanup of the data was done by hand after the initial pass.
See https://github.com/jacoblee93/oss-model-extraction-evals.
""",
)
18 changes: 10 additions & 8 deletions langchain_benchmarks/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from tabulate import tabulate

from langchain_benchmarks.schema import Task
from langchain_benchmarks.schema import ToolUsageTask, ExtractionTask
from langchain_benchmarks.extraction import email_task
from langchain_benchmarks.tool_usage.environments import (
relational_data,
type_writer,
Expand All @@ -15,9 +16,9 @@

@dataclasses.dataclass(frozen=True)
class Registry:
tasks: Sequence[Task]
tasks: Sequence[ToolUsageTask]

def get_task(self, name_or_id: Union[int, str]) -> Task:
def get_task(self, name_or_id: Union[int, str]) -> ToolUsageTask:
"""Get the environment with the given name."""
for env in self.tasks:
if env.name == name_or_id or env.id == name_or_id:
Expand Down Expand Up @@ -58,7 +59,7 @@ def _repr_html_(self) -> str:
]
return tabulate(table, headers=headers, tablefmt="html")

def __getitem__(self, key: Union[int, str]) -> Task:
def __getitem__(self, key: Union[int, str]) -> ToolUsageTask:
"""Get an environment from the registry."""
if isinstance(key, slice):
raise NotImplementedError("Slicing is not supported.")
Expand All @@ -72,7 +73,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
# Using lower case naming to make a bit prettier API when used in a notebook
registry = Registry(
tasks=[
Task(
ToolUsageTask(
id=0,
name="Tool Usage - Relational Data",
dataset_id=relational_data.DATASET_ID,
Expand Down Expand Up @@ -103,7 +104,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
"""
),
),
Task(
ToolUsageTask(
id=1,
name="Tool Usage - Typewriter (1 func)",
dataset_id="placeholder",
Expand Down Expand Up @@ -131,7 +132,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
"""
),
),
Task(
ToolUsageTask(
id=2,
name="Tool Usage - Typewriter",
dataset_id="placeholder",
Expand Down Expand Up @@ -161,7 +162,7 @@ def __getitem__(self, key: Union[int, str]) -> Task:
"""
),
),
Task(
ToolUsageTask(
id=3,
name="Multiverse Math",
dataset_id="placeholder",
Expand All @@ -187,5 +188,6 @@ def __getitem__(self, key: Union[int, str]) -> Task:
"""
),
),
email_task.EmailTask,
]
)
41 changes: 27 additions & 14 deletions langchain_benchmarks/schema.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Schema for the Langchain Benchmarks."""
import dataclasses
from typing import List, Callable, Any, Optional
from typing import List, Callable, Any, Optional, Type

from langchain.tools import BaseTool
from pydantic import BaseModel
from tabulate import tabulate


@dataclasses.dataclass(frozen=True)
class Environment:
class ToolUsageEnvironment:
"""An instance of an environment for tool usage."""

tools: List[BaseTool]
Expand All @@ -18,8 +19,8 @@ class Environment:


@dataclasses.dataclass(frozen=True)
class Task:
"""A definition for a task."""
class BaseTask:
"""A definition of a task."""

id: int
"""The ID of the environment."""
Expand All @@ -28,35 +29,47 @@ class Task:

dataset_id: str
"""The ID of the langsmith public dataset.
This dataset contains expected inputs/outputs for the environment, and
can be used to evaluate the performance of a model/agent etc.
"""

create_environment: Callable[
[], Environment
] # Specialized for tool usage; refactor potentially
"""Factory that returns an environment."""

description: str
"""Description of the task for a data science practitioner.
This can contain information about the task, the dataset, the tools available
etc.
"""

instructions: str
"""Instructions for the agent/chain/llm."""

def _repr_html_(self) -> str:
"""Return an HTML representation of the environment."""
table = [
["ID", self.id],
["Name", self.name],
["Type", self.__class__.__name__],
["Dataset ID", self.dataset_id],
["Description", self.description[:100] + "..."],
]
return tabulate(
table,
tablefmt="html",
)


@dataclasses.dataclass(frozen=True)
class ToolUsageTask(BaseTask):
"""A definition for a task."""

create_environment: Callable[[], ToolUsageEnvironment]
"""Factory that returns an environment."""

instructions: str
"""Instructions for the agent/chain/llm."""


@dataclasses.dataclass(frozen=True)
class ExtractionTask(BaseTask):
"""A definition for an extraction task."""

model: Type[BaseModel]
"""Get the model for the task."""
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from langchain.tools import tool, BaseTool

from langchain_benchmarks.schema import Environment
from langchain_benchmarks.schema import ToolUsageEnvironment


def multiply(a: float, b: float) -> float:
Expand Down Expand Up @@ -76,13 +76,13 @@ def negate(a: float) -> float:
# PUBLIC API


def get_environment() -> Environment:
def get_environment() -> ToolUsageEnvironment:
"""Create an environment."""
tools = cast(
List[BaseTool],
[tool(func) for func in [multiply, add, divide, subtract, power, log, negate]],
)
return Environment(
return ToolUsageEnvironment(
tools=tools,
read_state=None,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from langchain.tools import BaseTool, tool

from langchain_benchmarks.schema import Environment
from langchain_benchmarks.schema import ToolUsageEnvironment

USER_DATA = [
# IDs are not consecutive to prevent agents from guessing the ID
Expand Down Expand Up @@ -397,9 +397,9 @@ def get_tools() -> List[BaseTool]:
return [tool(f) for f in functions]


def get_environment() -> Environment:
def get_environment() -> ToolUsageEnvironment:
"""Create an environment."""
return Environment(
return ToolUsageEnvironment(
tools=get_tools(),
read_state=None,
)
Expand Down
8 changes: 3 additions & 5 deletions langchain_benchmarks/tool_usage/environments/type_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from langchain.tools import BaseTool, tool

from langchain_benchmarks.schema import Environment
from langchain_benchmarks.schema import ToolUsageEnvironment


@dataclasses.dataclass
Expand All @@ -32,7 +32,7 @@ def type_letter(letter: str) -> str:
# PUBLIC API


def get_environment() -> Environment:
def get_environment() -> ToolUsageEnvironment:
"""Create tools and state reader.
Attention: this is a factory function, so it will create a new environment
Expand All @@ -42,16 +42,14 @@ def get_environment() -> Environment:
A tuple of (tools, state_reader).
"""
paper = Paper(content="") # Start with an empty piece of paper
# functions = _get_available_functions(paper)

def _read_state() -> Any:
"""Read the state of the environment."""
return paper.content

# tools = cast(List[BaseTool], [tool(f) for f in functions])
tools = cast(List[BaseTool], [tool(function(paper))])

return Environment(
return ToolUsageEnvironment(
tools=tools,
read_state=_read_state,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from langchain.tools import BaseTool, tool

from langchain_benchmarks.schema import Environment
from langchain_benchmarks.schema import ToolUsageEnvironment


@dataclasses.dataclass
Expand Down Expand Up @@ -40,7 +40,7 @@ def _get_available_functions(paper: Paper) -> List[Callable]:
# PUBLIC API


def get_environment() -> Environment:
def get_environment() -> ToolUsageEnvironment:
"""Create tools and state reader.
Attention: this is a factory function, so it will create a new environment
Expand All @@ -58,7 +58,7 @@ def _read_state() -> Any:

tools = cast(List[BaseTool], [tool(f) for f in functions])

return Environment(
return ToolUsageEnvironment(
tools=tools,
read_state=_read_state,
)
Empty file.
3 changes: 3 additions & 0 deletions tests/unit_tests/extraction/test_email_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test_email_extraction() -> None:
"""Try to import the email task."""
from langchain_benchmarks.extraction import email_task # noqa: F401

0 comments on commit f79d797

Please sign in to comment.