Skip to content

Commit

Permalink
兼容PYDANTIC_V2
Browse files Browse the repository at this point in the history
  • Loading branch information
glide-the committed Jul 17, 2024
1 parent 78074eb commit ce72ee3
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 33 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
PYTHONIOENCODING=utf-8
11 changes: 5 additions & 6 deletions langchain_glm/agents/zhipuai_all_tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from langchain_core.runnables.base import RunnableBindingBase
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic.v1 import BaseModel, Field, validator
from pydantic.v1 import Field, validator
from typing_extensions import ClassVar
from zhipuai.core import PYDANTIC_V2, BaseModel, ConfigDict

from langchain_glm.agent_toolkits.all_tools.registry import (
TOOL_STRUCT_TYPE_TO_TOOL_CLASS,
Expand Down Expand Up @@ -143,8 +145,6 @@ class ZhipuAIAllToolsRunnable(RunnableSerializable[Dict, OutputType]):
"""工具模型"""
callback: AgentExecutorAsyncIteratorCallbackHandler
"""ZhipuAI AgentExecutor callback."""
check_every_ms: float = 1_000.0
"""Frequency with which to check run progress in ms."""
intermediate_steps: List[Tuple[AgentAction, BaseToolOutput]] = []
"""intermediate_steps to store the data to be processed."""
history: List[Union[List, Tuple, Dict]] = []
Expand All @@ -153,9 +153,8 @@ class ZhipuAIAllToolsRunnable(RunnableSerializable[Dict, OutputType]):
class Config:
arbitrary_types_allowed = True

@validator("intermediate_steps", pre=True, each_item=True, allow_reuse=True)
def check_intermediate_steps(cls, v):
return v
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

@staticmethod
def paser_all_tools(
Expand Down
12 changes: 8 additions & 4 deletions langchain_glm/agents/zhipuai_all_tools/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from numbers import Number
from typing import Any, Dict, List, Optional, Union

from typing_extensions import Self
from zhipuai.core import BaseModel
from typing_extensions import ClassVar, Self
from zhipuai.core import PYDANTIC_V2, BaseModel, ConfigDict


class MsgType:
Expand All @@ -18,8 +18,12 @@ class MsgType:


class AllToolsBaseComponent(BaseModel):
class Config:
arbitrary_types_allowed = True
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
else:

class Config:
arbitrary_types_allowed = True

@classmethod
@abstractmethod
Expand Down
12 changes: 8 additions & 4 deletions langchain_glm/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,18 @@
)
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.utils import build_extra_kwargs
from typing_extensions import ClassVar
from zhipuai.core import PYDANTIC_V2, ConfigDict

from langchain_glm.chat_models.all_tools_message import (
ALLToolsMessageChunk,
_paser_chunk,
)

if TYPE_CHECKING:
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
from zhipuai.core import BaseModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -378,10 +380,12 @@ def is_lc_serializable(cls) -> bool:
http_client: Union[Any, None] = None
"""Optional httpx.Client."""

class Config:
"""Configuration for this pydantic object."""
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(populate_by_name=True)
else:

allow_population_by_field_name = True
class Config:
allow_population_by_field_name = True

@root_validator(pre=True, allow_reuse=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
16 changes: 9 additions & 7 deletions langchain_glm/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
cast,
)

import numpy as np
import tiktoken
import zhipuai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
Expand All @@ -35,6 +32,8 @@
get_from_dict_or_env,
get_pydantic_field_names,
)
from typing_extensions import ClassVar
from zhipuai.core import PYDANTIC_V2, BaseModel, ConfigDict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,11 +84,14 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
http_client: Union[Any, None] = None
"""Optional httpx.Client."""

class Config:
"""Configuration for this pydantic object."""
if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="forbid", populate_by_name=True
)
else:

extra = Extra.forbid
allow_population_by_field_name = True
class Config:
allow_population_by_field_name = True

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
3 changes: 0 additions & 3 deletions pytest.ini

This file was deleted.

4 changes: 2 additions & 2 deletions tests/assistant/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from fastapi import APIRouter, Body, FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from langchain.agents import tool
from langchain.tools.shell import ShellTool
from langchain_community.tools import ShellTool
from langchain_core.agents import AgentAction
from pydantic.v1 import BaseModel, Extra, Field
from pydantic.v1 import Extra, Field
from sse_starlette.sse import EventSourceResponse
from uvicorn import Config, Server
from zhipuai.core.logs import (
Expand Down
2 changes: 0 additions & 2 deletions tests/integration_tests/all_tools/test_alltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import pytest
from langchain.agents import tool
from langchain.tools.shell import ShellTool
from pydantic.v1 import BaseModel, Extra, Field

from langchain_glm.agent_toolkits import BaseToolOutput
from langchain_glm.agents.zhipuai_all_tools import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_paser_web_browser_success_tool_calls():
{
"title": "昨夜今晨,京津冀发生这些大事(2024年6月27日) - 腾讯网",
"link": "https://new.qq.com/rain/a/20240627A013AI00",
"content": "北京首套房首付比例最低2成. “517”楼市新政的“靴子”在北京落地了。. "
"content": "北京首套房首付比例最低2成. “517”楼市新政的“靴子”在北京落地了。. ",
}
]
},
Expand Down
6 changes: 4 additions & 2 deletions tests/unit_tests/test_code_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
import sys

import pytest

from langchain_glm.agent_toolkits.all_tools.code_interpreter_tool import (
CodeInterpreterAllToolExecutor,
)
Expand All @@ -14,8 +16,8 @@ def test_python_ast_interpreter():
)
print(out.data)
assert (
out.data
!= """Access:code_interpreter,python_repl_ast, Message: print('Hello, World!')
out.data
!= """Access:code_interpreter,python_repl_ast, Message: print('Hello, World!')
Hello, World!
"""
)
4 changes: 2 additions & 2 deletions tests/unit_tests/tools_bind/test_tools_bind.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
from langchain.agents import tool as register_tool
from langchain.tools.shell import ShellTool
from langchain_community.tools import ShellTool
from langchain_core.runnables import RunnableBinding
from pydantic.v1 import BaseModel, Extra, Field
from pydantic.v1 import Extra, Field

from langchain_glm.agent_toolkits import BaseToolOutput
from langchain_glm.agents.zhipuai_all_tools.base import _get_assistants_tool
Expand Down

0 comments on commit ce72ee3

Please sign in to comment.