Commit

Updated pr.
lu-ohai committed Dec 10, 2024
1 parent 25c26f2 commit 872e41b
Showing 4 changed files with 52 additions and 17 deletions.
26 changes: 21 additions & 5 deletions ads/llm/langchain/plugins/chat_models/oci_data_science.py
@@ -93,7 +93,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
Key init args — client params:
auth: dict
ADS auth dictionary for OCI authentication.
-headers: Optional[Dict]
+default_headers: Optional[Dict]
The headers to be added to the Model Deployment request.
Instantiate:
@@ -111,7 +111,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"temperature": 0.2,
# other model parameters ...
},
-headers={
+default_headers={
"route": "/v1/chat/completions",
# other request headers ...
},
@@ -263,9 +263,6 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
"""Stop words to use when generating. Model output is cut off
at the first occurrence of any of these substrings."""

-headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
-"""The headers to be added to the Model Deployment request."""

@model_validator(mode="before")
@classmethod
def validate_openai(cls, values: Any) -> Any:
@@ -300,6 +297,25 @@ def _default_params(self) -> Dict[str, Any]:
"stream": self.streaming,
}

+def _headers(
+    self, is_async: Optional[bool] = False, body: Optional[dict] = None
+) -> Dict:
+    """Construct and return the headers for a request.
+    Args:
+        is_async (bool, optional): Indicates if the request is asynchronous.
+            Defaults to `False`.
+        body (optional): The request body to be included in the headers if
+            the request is asynchronous.
+    Returns:
+        Dict: A dictionary containing the appropriate headers for the request.
+    """
+    return {
+        "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+        **super()._headers(is_async=is_async, body=body),
+    }

def _generate(
self,
messages: List[BaseMessage],
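That is the extent of the changes to the chat-model module: the public `headers` field becomes `default_headers`, and the chat route is now supplied by the new `_headers` override instead of a dict default. For context, a complete instantiation with the renamed parameter might look like the following sketch. The endpoint URL and model name are placeholders, and the `ads.llm` import assumes the class is re-exported there (otherwise import it from the module path listed above).

    import ads
    from ads.llm import ChatOCIModelDeployment

    # Any ADS-supported auth method works; resource principal is shown as an example.
    ads.set_auth("resource_principal")

    chat = ChatOCIModelDeployment(
        # Placeholder model deployment endpoint, not a value from this commit.
        endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
        model="odsc-llm",
        model_kwargs={"temperature": 0.2},
        # Renamed parameter; passing "route" is optional because _headers()
        # now falls back to DEFAULT_INFERENCE_ENDPOINT_CHAT.
        default_headers={"route": "/v1/chat/completions"},
    )
    print(chat.invoke("Tell me a joke.").content)
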
Second changed file (filename not shown):
@@ -85,7 +85,7 @@ class BaseOCIModelDeployment(Serializable):
max_retries: int = 3
"""Maximum number of retries to make when generating."""

-headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT}
+default_headers: Optional[Dict[str, Any]] = None
"""The headers to be added to the Model Deployment request."""

@model_validator(mode="before")
@@ -127,7 +127,7 @@ def _headers(
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
-headers = self.headers
+headers = self.default_headers or {}
if is_async:
signer = self.auth["signer"]
_req = requests.Request("POST", self.endpoint, json=body)
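Because the field default changed from a pre-populated dict to None (first hunk of this file), the `or {}` guard above is what keeps the subsequent header construction safe. A minimal, standalone illustration of the failure it avoids (the route value is illustrative only, not code from the commit):

    default_headers = None                               # the new field default
    # {"route": "/v1/completions", **default_headers}    # would raise a TypeError (None is not a mapping)
    headers = default_headers or {}                      # the guard added in this commit
    merged = {"route": "/v1/completions", **headers}
    print(merged)                                        # {'route': '/v1/completions'}
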
@@ -485,6 +485,25 @@ def _identifying_params(self) -> Dict[str, Any]:
**self._default_params,
}

+def _headers(
+    self, is_async: Optional[bool] = False, body: Optional[dict] = None
+) -> Dict:
+    """Construct and return the headers for a request.
+    Args:
+        is_async (bool, optional): Indicates if the request is asynchronous.
+            Defaults to `False`.
+        body (optional): The request body to be included in the headers if
+            the request is asynchronous.
+    Returns:
+        Dict: A dictionary containing the appropriate headers for the request.
+    """
+    return {
+        "route": DEFAULT_INFERENCE_ENDPOINT,
+        **super()._headers(is_async=is_async, body=body),
+    }

def _generate(
self,
prompts: List[str],
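One more consequence of the new `_headers` overrides is precedence: the class default route is written first and the result of `super()._headers()` (which starts from `default_headers or {}`) is spread after it, so a caller-supplied "route" entry wins over DEFAULT_INFERENCE_ENDPOINT or DEFAULT_INFERENCE_ENDPOINT_CHAT. A sketch of that behaviour, assuming ADS auth is already configured, that OCIModelDeploymentTGI is importable from `ads.llm`, and that the synchronous path of the base `_headers` returns `default_headers` unchanged; endpoint and route values are placeholders:

    from ads.llm import OCIModelDeploymentTGI

    llm = OCIModelDeploymentTGI(
        endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
        model="odsc-llm",
        # Caller-supplied route (placeholder path for illustration).
        default_headers={"route": "/v1/completions/custom"},
    )
    # The user value overrides the class default "/v1/completions"
    # because **super()._headers() is spread after the default route.
    assert llm._headers().get("route") == "/v1/completions/custom"

    # Without default_headers, the default route applies, as the tests below verify.
    llm_default = OCIModelDeploymentTGI(endpoint=llm.endpoint, model="odsc-llm")
    assert llm_default._headers().get("route") == "/v1/completions"

The unit-test updates in the remaining two files check exactly this default-route behaviour through `_headers()` instead of reading the old `headers` attribute.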
Third changed file (filename not shown):
@@ -26,7 +26,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_ENDPOINT = "/v1/chat/completions"
CONST_COMPLETION_ROUTE = "/v1/chat/completions"
CONST_COMPLETION_RESPONSE = {
"id": "chat-123456789",
"object": "chat.completion",
@@ -124,7 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -137,7 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
def test_invoke_tgi(*args: Any) -> None:
"""Tests invoking TGI endpoint using OpenAI Spec."""
llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -152,7 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = None
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -191,7 +191,7 @@ async def test_stream_async(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
Fourth changed file (filename not shown):
@@ -24,7 +24,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_ENDPOINT = "/v1/completions"
CONST_COMPLETION_ROUTE = "/v1/completions"
CONST_COMPLETION_RESPONSE = {
"choices": [
{
@@ -117,7 +117,7 @@ async def mocked_async_streaming_response(
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

@@ -130,7 +130,7 @@ def test_stream_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = ""
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -148,7 +148,7 @@ def test_generate_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

@@ -167,7 +167,7 @@ async def test_stream_async(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
-assert llm.headers == {"route": CONST_ENDPOINT}
+assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
