Merge branch 'main' into dataflow_changes
ahosler authored Dec 18, 2024
2 parents 21a813d + de214d3 commit 8149235
Showing 6 changed files with 97 additions and 34 deletions.
18 changes: 9 additions & 9 deletions .github/workflows/run-unittests-py39-py310.yml
@@ -74,16 +74,16 @@ jobs:
name: "Test env setup"
timeout-minutes: 30

- name: "Run hpo tests"
timeout-minutes: 10
shell: bash
if: ${{ matrix.name }} == "unitary"
run: |
set -x # print commands that are executed
# - name: "Run hpo tests"
# timeout-minutes: 10
# shell: bash
# if: ${{ matrix.name }} == "unitary"
# run: |
# set -x # print commands that are executed

# Run hpo tests, which hangs if run together with all unitary tests
python -m pytest -v -p no:warnings -n auto --dist loadfile \
tests/unitary/with_extras/hpo
# # Run hpo tests, which hangs if run together with all unitary tests
# python -m pytest -v -p no:warnings -n auto --dist loadfile \
# tests/unitary/with_extras/hpo

- name: "Run unitary tests folder with maximum ADS dependencies"
timeout-minutes: 60
11 changes: 6 additions & 5 deletions ads/llm/guardrails/base.py
@@ -1,17 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2023 Oracle and/or its affiliates.
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import datetime
import functools
import operator
import importlib.util
import operator
import sys
from typing import Any, List, Optional, Union

from typing import Any, List, Dict, Tuple
from langchain.schema.prompt import PromptValue
from langchain.tools.base import BaseTool, ToolException
from pydantic import BaseModel, model_validator
@@ -207,7 +206,9 @@ def _preprocess(self, input: Any) -> str:
return input.to_string()
return str(input)

def _to_args_and_kwargs(self, tool_input: Any) -> Tuple[Tuple, Dict]:
def _to_args_and_kwargs(
self, tool_input: Union[str, dict], tool_call_id: Optional[str]
) -> tuple[tuple, dict]:
if isinstance(tool_input, dict):
return (), tool_input
else:
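For context, the widened `_to_args_and_kwargs` signature tracks newer LangChain `BaseTool` internals, which pass a `tool_call_id` along with the tool input. Below is a minimal sketch of the dispatch behaviour, assuming (as in LangChain's base implementation) that non-dict inputs are forwarded as a single positional argument; the standalone helper name is hypothetical and only illustrates the branch shown in the diff.

from typing import Optional, Union

def to_args_and_kwargs(
    tool_input: Union[str, dict], tool_call_id: Optional[str] = None
) -> tuple[tuple, dict]:
    # Dict inputs become keyword arguments for the guardrail's compute step;
    # anything else is forwarded as one positional argument (assumed behaviour
    # for the truncated else branch).
    if isinstance(tool_input, dict):
        return (), tool_input
    return (tool_input,), {}

# Expected dispatch under these assumptions:
assert to_args_and_kwargs({"text": "hello"}, "call-1") == ((), {"text": "hello"})
assert to_args_and_kwargs("hello") == (("hello",), {})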
43 changes: 34 additions & 9 deletions ads/llm/langchain/plugins/chat_models/oci_data_science.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2023 Oracle and/or its affiliates.
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""Chat model for OCI data science model deployment endpoint."""

@@ -50,6 +49,7 @@
)

logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


def _is_pydantic_class(obj: Any) -> bool:
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
Key init args — client params:
auth: dict
ADS auth dictionary for OCI authentication.
default_headers: Optional[Dict]
The headers to be added to the Model Deployment request.
Instantiate:
.. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"temperature": 0.2,
# other model parameters ...
},
default_headers={
"route": "/v1/chat/completions",
# other request headers ...
},
)
Invocation:
@@ -291,6 +297,25 @@ def _default_params(self) -> Dict[str, Any]:
"stream": self.streaming,
}

def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.
Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
**super()._headers(is_async=is_async, body=body),
}

def _generate(
self,
messages: List[BaseMessage],
@@ -704,7 +729,7 @@ def _process_response(self, response_json: dict) -> ChatResult:

for choice in choices:
message = _convert_dict_to_message(choice["message"])
generation_info = dict(finish_reason=choice.get("finish_reason"))
generation_info = {"finish_reason": choice.get("finish_reason")}
if "logprobs" in choice:
generation_info["logprobs"] = choice["logprobs"]

@@ -794,7 +819,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
"""Number of most likely tokens to consider at each step."""

min_p: Optional[float] = 0.0
"""Float that represents the minimum probability for a token to be considered.
"""Float that represents the minimum probability for a token to be considered.
Must be in [0,1]. 0 to disable this."""

repetition_penalty: Optional[float] = 1.0
@@ -818,7 +843,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
the EOS token is generated."""

min_tokens: Optional[int] = 0
"""Minimum number of tokens to generate per output sequence before
"""Minimum number of tokens to generate per output sequence before
EOS or stop_token_ids can be generated"""

stop_token_ids: Optional[List[int]] = None
@@ -836,7 +861,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
tool_choice: Optional[str] = None
"""Whether to use tool calling.
Defaults to None, tool calling is disabled.
Tool calling requires model support and the vLLM to be configured
Tool calling requires model support and the vLLM to be configured
with `--tool-call-parser`.
Set this to `auto` for the model to make tool calls automatically.
Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +981,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
"""Total probability mass of tokens to consider at each step."""

top_logprobs: Optional[int] = None
"""An integer between 0 and 5 specifying the number of most
likely tokens to return at each token position, each with an
associated log probability. logprobs must be set to true if
"""An integer between 0 and 5 specifying the number of most
likely tokens to return at each token position, each with an
associated log probability. logprobs must be set to true if
this parameter is used."""

@property
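Taken together, the new `default_headers` field and the `_headers` override mean every chat request carries a `route` header that defaults to `/v1/chat/completions` and can be overridden per deployment. A minimal usage sketch follows, assuming OCI authentication is already configured for the tenancy; the endpoint URL and generation parameters are placeholders, not values from this commit.

import ads
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
    ChatOCIModelDeployment,
)

ads.set_auth("resource_principal")  # or api_key / security_token, as appropriate

chat = ChatOCIModelDeployment(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",
    model_kwargs={"temperature": 0.2, "max_tokens": 512},
    # Optional: merged into every request; without it, _headers() supplies
    # "route": "/v1/chat/completions" automatically.
    default_headers={"route": "/v1/chat/completions"},
)

print(chat.invoke("Tell me a joke.").content)

Passing the route as a header rather than appending it to the endpoint keeps a single `/predict` URL working for both chat and completion payloads on the same deployment.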
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2023 Oracle and/or its affiliates.
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


@@ -24,6 +23,7 @@

import aiohttp
import requests
from langchain_community.utilities.requests import Requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -34,14 +34,13 @@
from langchain_core.utils import get_from_dict_or_env
from pydantic import Field, model_validator

from langchain_community.utilities.requests import Requests

logger = logging.getLogger(__name__)


DEFAULT_TIME_OUT = 300
DEFAULT_CONTENT_TYPE_JSON = "application/json"
DEFAULT_MODEL_NAME = "odsc-llm"
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"


class TokenExpiredError(Exception):
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
max_retries: int = 3
"""Maximum number of retries to make when generating."""

default_headers: Optional[Dict[str, Any]] = None
"""The headers to be added to the Model Deployment request."""

@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
@@ -101,7 +103,7 @@ def validate_environment(cls, values: Dict) -> Dict:
"Please install it with `pip install oracle_ads`."
) from ex

if not values.get("auth", None):
if not values.get("auth"):
values["auth"] = ads.common.auth.default_signer()

values["endpoint"] = get_from_dict_or_env(
@@ -125,12 +127,12 @@ def _headers(
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
headers = self.default_headers or {}
if is_async:
signer = self.auth["signer"]
_req = requests.Request("POST", self.endpoint, json=body)
req = _req.prepare()
req = signer(req)
headers = {}
for key, value in req.headers.items():
headers[key] = value

@@ -140,7 +142,7 @@
)
return headers

return (
headers.update(
{
"Content-Type": DEFAULT_CONTENT_TYPE_JSON,
"enable-streaming": "true",
@@ -152,6 +154,8 @@
}
)

return headers

def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
@@ -357,7 +361,7 @@ def _refresh_signer(self) -> bool:
self.auth["signer"].refresh_security_token()
return True
return False

@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by LangChain."""
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
model="odsc-llm",
streaming=True,
model_kwargs={"frequency_penalty": 1.0},
headers={
"route": "/v1/completions",
# other request headers ...
}
)
llm.invoke("tell me a joke.")
@@ -477,6 +485,25 @@ def _identifying_params(self) -> Dict[str, Any]:
**self._default_params,
}

def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.
Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT,
**super()._headers(is_async=is_async, body=body),
}

def _generate(
self,
prompts: List[str],
@@ -712,9 +739,9 @@ def _process_response(self, response_json: dict) -> List[Generation]:
def _generate_info(self, choice: dict) -> Any:
"""Extracts generation info from the response."""
gen_info = {}
finish_reason = choice.get("finish_reason", None)
logprobs = choice.get("logprobs", None)
index = choice.get("index", None)
finish_reason = choice.get("finish_reason")
logprobs = choice.get("logprobs")
index = choice.get("index")
if finish_reason:
gen_info.update({"finish_reason": finish_reason})
if logprobs is not None:
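The behavioural change in `_headers` is that user-supplied `default_headers` now seed the header dict instead of being discarded, with the fixed content-type/streaming keys layered on top, and the completion-model subclass prepends a `route` default that a caller can still override. A self-contained sketch of that merge order, assuming the precedence shown in the diff; the function names below are illustrative, not the library code itself.

from typing import Any, Dict, Optional

DEFAULT_CONTENT_TYPE_JSON = "application/json"
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"

def build_base_headers(
    default_headers: Optional[Dict[str, Any]] = None, streaming: bool = False
) -> Dict[str, Any]:
    # Simplified sketch of the synchronous branch of BaseOCIModelDeployment._headers:
    # caller-supplied headers first, then the fixed keys on top.
    headers = dict(default_headers or {})
    if streaming:
        headers.update(
            {"Content-Type": DEFAULT_CONTENT_TYPE_JSON, "enable-streaming": "true"}
        )
    else:
        headers.update({"Content-Type": DEFAULT_CONTENT_TYPE_JSON})
    return headers

def build_completion_headers(
    default_headers: Optional[Dict[str, Any]] = None, streaming: bool = False
) -> Dict[str, Any]:
    # Sketch of OCIModelDeploymentLLM._headers: a default route, overridable by a
    # "route" key in default_headers (which lands later in the merge).
    return {
        "route": DEFAULT_INFERENCE_ENDPOINT,
        **build_base_headers(default_headers, streaming),
    }

# The default route applies unless default_headers supplies its own "route".
assert build_completion_headers()["route"] == "/v1/completions"
assert build_completion_headers({"route": "/custom/path"})["route"] == "/custom/path"
assert build_completion_headers(streaming=True)["enable-streaming"] == "true"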
@@ -26,6 +26,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/chat/completions"
CONST_COMPLETION_RESPONSE = {
"id": "chat-123456789",
"object": "chat.completion",
@@ -123,6 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -135,6 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
def test_invoke_tgi(*args: Any) -> None:
"""Tests invoking TGI endpoint using OpenAI Spec."""
llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -149,6 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = None
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -187,6 +191,7 @@ async def test_stream_async(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
@@ -24,6 +24,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/completions"
CONST_COMPLETION_RESPONSE = {
"choices": [
{
@@ -116,6 +117,7 @@ async def mocked_async_streaming_response(
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

@@ -128,6 +130,7 @@ def test_stream_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = ""
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -145,6 +148,7 @@ def test_generate_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

@@ -163,6 +167,7 @@ async def test_stream_async(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
