
Commit

Merge branch 'main' into fix-anthropic-kevin
sarahwooders authored Dec 17, 2024
2 parents 67f7e58 + e32417e commit 33ed5a2
Showing 6 changed files with 54 additions and 12 deletions.
1 change: 1 addition & 0 deletions letta/constants.py
@@ -23,6 +23,7 @@

 # embeddings
 MAX_EMBEDDING_DIM = 4096  # maximum supported embeding size - do NOT change or else DBs will need to be reset
+DEFAULT_EMBEDDING_CHUNK_SIZE = 300
 
 # tokenizers
 EMBEDDING_TO_TOKENIZER_MAP = {
2 changes: 2 additions & 0 deletions letta/providers.py
@@ -13,6 +13,7 @@


 class Provider(BaseModel):
+    name: str = Field(..., description="The name of the provider")
 
     def list_llm_models(self) -> List[LLMConfig]:
         return []
@@ -465,6 +466,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:

 class GoogleAIProvider(Provider):
     # gemini
+    name: str = "google_ai"
     api_key: str = Field(..., description="API key for the Google AI API.")
     base_url: str = "https://generativelanguage.googleapis.com"

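
The new required `name` field means every Provider subclass now identifies itself by a fixed string, as GoogleAIProvider does with "google_ai" above. A minimal sketch of that pattern (the ToyProvider subclass is hypothetical, for illustration only, and is not part of this commit):

# Minimal sketch, not from the diff: a toy Provider subclass satisfying the
# new required `name` field with a class-level default, as GoogleAIProvider does.
from pydantic import BaseModel, Field


class Provider(BaseModel):
    name: str = Field(..., description="The name of the provider")


class ToyProvider(Provider):  # hypothetical subclass for illustration
    name: str = "toy_provider"


print(ToyProvider().name)  # -> "toy_provider"; no explicit name needed at construction

The provider name presumably becomes the left-hand side of the provider/model-name handles introduced in letta/schemas/agent.py below.
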
35 changes: 35 additions & 0 deletions letta/schemas/agent.py
@@ -3,6 +3,7 @@

 from pydantic import BaseModel, Field, field_validator
 
+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
 from letta.schemas.block import CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.letta_base import OrmMetadataBase
@@ -107,6 +108,16 @@ class CreateAgent(BaseModel, validate_assignment=True): #
     include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.")
     description: Optional[str] = Field(None, description="The description of the agent.")
     metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
+    llm: Optional[str] = Field(
+        None,
+        description="The LLM configuration handle used by the agent, specified in the format "
+        "provider/model-name, as an alternative to specifying llm_config.",
+    )
+    embedding: Optional[str] = Field(
+        None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
+    )
+    context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
+    embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
 
     @field_validator("name")
     @classmethod
@@ -133,6 +144,30 @@ def validate_name(cls, name: str) -> str:

         return name
 
+    @field_validator("llm")
+    @classmethod
+    def validate_llm(cls, llm: Optional[str]) -> Optional[str]:
+        if not llm:
+            return llm
+
+        provider_name, model_name = llm.split("/", 1)
+        if not provider_name or not model_name:
+            raise ValueError("The llm config handle should be in the format provider/model-name")
+
+        return llm
+
+    @field_validator("embedding")
+    @classmethod
+    def validate_embedding(cls, embedding: Optional[str]) -> Optional[str]:
+        if not embedding:
+            return embedding
+
+        provider_name, model_name = embedding.split("/", 1)
+        if not provider_name or not model_name:
+            raise ValueError("The embedding config handle should be in the format provider/model-name")
+
+        return embedding
+
 
 class UpdateAgent(BaseModel):
     name: Optional[str] = Field(None, description="The name of the agent.")
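
With these fields, an agent can be requested by handle instead of a full LLMConfig/EmbeddingConfig. A minimal sketch of the assumed usage, mirroring the updated tests further down (the agent name and context window value are made up for illustration; other CreateAgent fields keep their defaults):

# Minimal sketch of the assumed usage; handles follow the provider/model-name
# format enforced by the validators above.
from letta.schemas.agent import CreateAgent

request = CreateAgent(
    name="handle_demo_agent",  # hypothetical name
    memory_blocks=[],
    llm="openai/gpt-4",  # resolved server-side to an LLMConfig
    embedding="openai/text-embedding-ada-002",  # resolved server-side to an EmbeddingConfig
    context_window_limit=8192,  # optional cap, checked against the model's maximum
)

Passing a handle without a slash (e.g. llm="gpt-4") fails the validate_llm check above.
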
4 changes: 3 additions & 1 deletion letta/server/server.py
@@ -1309,7 +1309,9 @@ def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional

         if context_window_limit:
             if context_window_limit > llm_config.context_window:
-                raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
+                raise ValueError(
+                    f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})"
+                )
             llm_config.context_window = context_window_limit
 
         return llm_config
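
The reformatted raise keeps the original behavior: a caller-supplied context_window_limit may shrink the configured context window but never exceed the model's maximum. A standalone sketch of that rule (a plain function with assumed names, not the SyncServer method itself):

# Standalone sketch of the check above; names are illustrative.
from typing import Optional


def apply_context_window_limit(context_window: int, context_window_limit: Optional[int]) -> int:
    if context_window_limit:
        if context_window_limit > context_window:
            raise ValueError(
                f"Context window limit ({context_window_limit}) is greater than maximum of ({context_window})"
            )
        return context_window_limit
    return context_window


assert apply_context_window_limit(8192, 4096) == 4096  # a lower limit overrides the default
assert apply_context_window_limit(8192, None) == 8192  # no limit keeps the model default
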
3 changes: 3 additions & 0 deletions letta/services/agent_manager.py
@@ -61,6 +61,9 @@ def create_agent(
     ) -> PydanticAgentState:
         system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)
 
+        if not agent_create.llm_config or not agent_create.embedding_config:
+            raise ValueError("llm_config and embedding_config are required")
+
         # create blocks (note: cannot be linked into the agent_id is created)
         block_ids = list(agent_create.block_ids or [])  # Create a local copy to avoid modifying the original
         for create_block in agent_create.memory_blocks:
21 changes: 10 additions & 11 deletions tests/test_server.py
@@ -24,7 +24,6 @@
 from letta.schemas.agent import CreateAgent, UpdateAgent
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.job import Job as PydanticJob
-from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message
 from letta.schemas.source import Source as PydanticSource
 from letta.server.server import SyncServer
@@ -329,8 +328,8 @@ def agent_id(server, user_id, base_tools):
name="test_agent",
tool_ids=[t.id for t in base_tools],
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
@@ -350,8 +349,8 @@ def other_agent_id(server, user_id, base_tools):
name="test_agent_other",
tool_ids=[t.id for t in base_tools],
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
@@ -618,8 +617,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
         request=CreateAgent(
             name="nonexistent_tools_agent",
             memory_blocks=[],
-            llm_config=LLMConfig.default_config("gpt-4"),
-            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            llm="openai/gpt-4",
+            embedding="openai/text-embedding-ada-002",
         ),
         actor=server.user_manager.get_user_or_default(user_id),
     )
@@ -904,8 +903,8 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
@@ -1091,8 +1090,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
include_base_tools=False,
),
actor=actor,
