Update LLM and ChatModel for OCI Data Science #4

Open · wants to merge 4 commits into base: langchain-ai-master

Changes from all commits
8 changes: 5 additions & 3 deletions docs/docs/integrations/chat/oci_data_science.ipynb
@@ -21,6 +21,8 @@
"\n",
"[OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in the Oracle Cloud Infrastructure.\n",
"\n",
"> For the latest updates, examples and experimental features, please see [ADS LangChain Integration](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/large_language_model/langchain_models.html).\n",
"\n",
"This notebooks goes over how to use a chat model hosted on a [OCI Data Science Model Deployment Service](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm).\n",
"\n",
"For authentication, the [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) library is used to automatically load credentials required for invoking the endpoint.\n",
@@ -112,6 +114,7 @@
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain_community.chat_models import ChatOCIModelDeploymentVLLM\n",
"\n",
"# Set authentication through environment variables\n",
@@ -302,13 +305,12 @@
}
],
"source": [
"import sys\n",
"import os\n",
"import sys\n",
"\n",
"from langchain_community.chat_models import ChatOCIModelDeployment\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"human\", \"List out the 5 states in the United State.\")]\n",
")\n",
@@ -349,8 +351,8 @@
}
],
"source": [
"from langchain_core.pydantic_v1 import BaseModel\n",
"from langchain_community.chat_models import ChatOCIModelDeployment\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Joke(BaseModel):\n",
(Second file: the LLM integration notebook; its file header was not captured.)
@@ -8,6 +8,8 @@
"\n",
"[OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in the Oracle Cloud Infrastructure.\n",
"\n",
"> For the latest updates, examples and experimental features, please see [ADS LangChain Integration](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/large_language_model/langchain_models.html).\n",
"\n",
"This notebooks goes over how to use an LLM hosted on a [OCI Data Science Model Deployment](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm).\n",
"\n",
"For authentication, the [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) library is used to automatically load credentials required for invoking the endpoint."
65 changes: 53 additions & 12 deletions libs/community/langchain_community/chat_models/oci_data_science.py
@@ -11,6 +11,8 @@
Optional,
Type,
Union,
Sequence,
Callable,
)

from langchain_core.callbacks import (
@@ -24,20 +26,17 @@
generate_from_stream,
)
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
+from langchain_core.tools import BaseTool
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.utils.function_calling import convert_to_openai_tool

-from langchain_community.adapters.openai import (
-    convert_dict_to_message,
-    convert_message_to_dict,
-)
-from langchain_community.chat_models.openai import _convert_delta_to_message_chunk
-from langchain_community.llms.oci_data_science_model_deployment_endpoint import (
+from pydantic import BaseModel, Field
+from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
DEFAULT_MODEL_NAME,
BaseOCIModelDeployment,
)
@@ -128,8 +127,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
.. code-block:: python

from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field

class Joke(BaseModel):
setup: str = Field(description="The setup of the joke")
@@ -148,7 +146,7 @@ class Joke(BaseModel):

Customized Usage:

-You can inherit from base class and overrwrite the `_process_response`, `_process_stream_response`,
+You can inherit from the base class and override the `_process_response`, `_process_stream_response`,
`_construct_json_body` to satisfy custom needs.

.. code-block:: python
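    # A hedged sketch of the pattern described above (the subclass name is
    # hypothetical; the original docstring example is not captured in this view):
    class MyChatModel(ChatOCIModelDeployment):
        def _process_response(self, response_json: dict) -> ChatResult:
            # Inspect or rewrite the raw JSON before the default parsing.
            return super()._process_response(response_json)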
@@ -188,6 +186,20 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
"""Stop words to use when generating. Model output is cut off
at the first occurrence of any of these substrings."""

@pre_init
def validate_environment( # pylint: disable=no-self-argument
cls, values: Dict
) -> Dict:
try:
import langchain_openai

except ImportError as ex:
raise ImportError(
"Could not import langchain_openai package. "
"Please install it with `pip install langchain_openai`."
) from ex
return values

@property
def _llm_type(self) -> str:
"""Return type of llm."""
@@ -542,8 +554,10 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
converted messages and additional parameters.

"""
from langchain_openai.chat_models.base import _convert_message_to_dict

return {
"messages": [convert_message_to_dict(m) for m in messages],
"messages": [_convert_message_to_dict(m) for m in messages],
**params,
}
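For illustration, a small standalone check of the body this method builds (the temperature value is an arbitrary example):

```python
from langchain_core.messages import HumanMessage
from langchain_openai.chat_models.base import _convert_message_to_dict

body = {
    "messages": [_convert_message_to_dict(HumanMessage(content="Hello"))],
    "temperature": 0.2,  # example of an extra param merged into the body
}
# body == {"messages": [{"role": "user", "content": "Hello"}], "temperature": 0.2}
```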

@@ -568,6 +582,8 @@ def _process_stream_response(
ValueError: If the response JSON is not well-formed or does not
contain the expected structure.
"""
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk

try:
choice = response_json["choices"][0]
if not isinstance(choice, dict):
@@ -606,6 +622,8 @@ def _process_response(self, response_json: dict) -> ChatResult:
contain the expected structure.

"""
from langchain_openai.chat_models.base import _convert_dict_to_message

generations = []
try:
choices = response_json["choices"]
@@ -617,7 +635,7 @@ def _process_response(self, response_json: dict) -> ChatResult:
) from e

for choice in choices:
-message = convert_dict_to_message(choice["message"])
+message = _convert_dict_to_message(choice["message"])
generation_info = dict(finish_reason=choice.get("finish_reason"))
if "logprobs" in choice:
generation_info["logprobs"] = choice["logprobs"]
@@ -636,6 +654,14 @@ def _process_response(self, response_json: dict) -> ChatResult:
}
return ChatResult(generations=generations, llm_output=llm_output)

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
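A hedged usage sketch for the new `bind_tools` (the tool and endpoint are assumptions; the deployment must run vLLM with `--tool-call-parser` enabled):

```python
from langchain_core.tools import tool

@tool
def get_weather(city: str) -> str:
    """Return the current weather for a city."""
    return f"Sunny in {city}"

chat = ChatOCIModelDeploymentVLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<MD_OCID>/predict",
    tool_choice="auto",  # see the new vLLM attribute below
)
ai_msg = chat.bind_tools([get_weather]).invoke("What's the weather in Boston?")
print(ai_msg.tool_calls)
```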


class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
"""OCI large language chat models deployed with vLLM.
@@ -739,6 +765,19 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
"""Whether to add spaces between special tokens in the output.
Defaults to True."""

tool_choice: Optional[str] = None
"""Whether to use tool calling.
Defaults to None; tool calling is disabled.
Tool calling requires model support and vLLM to be configured with `--tool-call-parser`.
Set this to `auto` to let the model decide whether to make tool calls.
Set this to `required` to force the model to always call one or more tools.
"""

chat_template: Optional[str] = None
"""Optional customized chat template.
Defaults to None, in which case the chat template from the tokenizer is used.
"""

@property
def _llm_type(self) -> str:
"""Return type of llm."""
@@ -785,6 +824,8 @@ def _get_model_params(self) -> List[str]:
"top_k",
"top_p",
"use_beam_search",
"tool_choice",
"chat_template",
]


libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py
@@ -14,6 +14,7 @@

import aiohttp
import requests
import traceback
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -22,9 +23,8 @@
from langchain_core.load.serializable import Serializable
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import get_from_dict_or_env, pre_init
-from pydantic import Field

from langchain_community.utilities.requests import Requests
+from pydantic import Field

logger = logging.getLogger(__name__)

@@ -37,14 +37,10 @@
class TokenExpiredError(Exception):
    """Raised when the token has expired."""

-    pass


class ServerError(Exception):
    """Raised when a server error is encountered during inference."""

-    pass

def _create_retry_decorator(
llm: "BaseOCIModelDeployment",
@@ -173,6 +169,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:
except TokenExpiredError as e:
raise e
except Exception as err:
traceback.print_exc()
logger.debug(
f"Requests payload: {data}. Requests arguments: "
f"url={self.endpoint},timeout={request_timeout},stream={stream}. "
@@ -219,6 +216,7 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
except TokenExpiredError as e:
raise e
except Exception as err:
traceback.print_exc()
logger.debug(
f"Requests payload: `{data}`. "
f"Stream mode={stream}. "
@@ -305,13 +303,16 @@ def _parse_stream_line(self, line: bytes) -> Optional[str]:
The processed line as a string if valid, otherwise `None`.
"""
line = line.strip()
-if line:
-    _line = line.decode("utf-8")
-    if "[DONE]" in _line:
-        return None
+if not line:
+    return None
+_line = line.decode("utf-8")
+
+if _line.lower().startswith("data:"):
+    _line = _line[5:].lstrip()
+
-    if _line.lower().startswith("data:"):
-        return _line[5:].lstrip()
+if _line.startswith("[DONE]"):
+    return None
+return _line
-return None
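A standalone mirror of the revised parsing logic, with the SSE cases spelled out:

```python
from typing import Optional

def parse_stream_line(line: bytes) -> Optional[str]:
    """Illustrative mirror of the revised _parse_stream_line."""
    line = line.strip()
    if not line:
        return None  # skip blank keep-alive lines
    text = line.decode("utf-8")
    if text.lower().startswith("data:"):
        text = text[5:].lstrip()  # strip the SSE "data:" prefix first
    if text.startswith("[DONE]"):
        return None  # end-of-stream sentinel
    return text

assert parse_stream_line(b'data: {"x": 1}') == '{"x": 1}'
assert parse_stream_line(b"data: [DONE]") is None
assert parse_stream_line(b"") is None
```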

async def _aiter_sse(
@@ -587,11 +588,11 @@ def _stream(
response = self.completion_with_retry(
data=body, run_manager=run_manager, stream=True, **requests_kwargs
)

for line in self._parse_stream(response.iter_lines()):
chunk = self._handle_sse_line(line)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)

yield chunk
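For reference, how this streaming path is typically consumed (the endpoint is a placeholder):

```python
from langchain_community.llms import OCIModelDeploymentLLM

llm = OCIModelDeploymentLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<MD_OCID>/predict",
    model="odsc-llm",
)
for chunk in llm.stream("Tell me a short story."):
    print(chunk, end="", flush=True)
```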

async def _astream(
@@ -749,7 +750,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):

"""

-api: Literal["/generate", "/v1/completions"] = "/generate"
+api: Literal["/generate", "/v1/completions"] = "/v1/completions"
"""API spec."""

frequency_penalty: float = 0.0
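With the default switched to the OpenAI-spec route, a deployment that only exposes the legacy TGI route would opt back in explicitly (a sketch; the endpoint is a placeholder):

```python
from langchain_community.llms import OCIModelDeploymentTGI

llm = OCIModelDeploymentTGI(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<MD_OCID>/predict",
    api="/generate",  # override the new "/v1/completions" default
)
```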