diff --git a/ads/llm/langchain/plugins/chat_models/oci_data_science.py b/ads/llm/langchain/plugins/chat_models/oci_data_science.py
index 9785dd0fd..b2b0e5f16 100644
--- a/ads/llm/langchain/plugins/chat_models/oci_data_science.py
+++ b/ads/llm/langchain/plugins/chat_models/oci_data_science.py
@@ -93,7 +93,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
-        headers: Optional[Dict]
+        default_headers: Optional[Dict]
             The headers to be added to the Model Deployment request.
 
     Instantiate:
@@ -111,7 +111,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
                 "temperature": 0.2,
                 # other model parameters ...
             },
-            headers={
+            default_headers={
                 "route": "/v1/chat/completions",
                 # other request headers ...
             },
@@ -263,9 +263,6 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
     """Stop words to use when generating. Model output is cut off
     at the first occurrence of any of these substrings."""
 
-    headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
-    """The headers to be added to the Model Deployment request."""
-
     @model_validator(mode="before")
     @classmethod
     def validate_openai(cls, values: Any) -> Any:
@@ -300,6 +297,25 @@ def _default_params(self) -> Dict[str, Any]:
             "stream": self.streaming,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
diff --git a/ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py b/ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py
index 3be494270..20f1fae07 100644
--- a/ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py
+++ b/ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py
@@ -85,7 +85,7 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""
 
-    headers: Optional[Dict[str, Any]] = {"route": DEFAULT_INFERENCE_ENDPOINT}
+    default_headers: Optional[Dict[str, Any]] = None
     """The headers to be added to the Model Deployment request."""
 
     @model_validator(mode="before")
@@ -127,7 +127,7 @@ def _headers(
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
-        headers = self.headers
+        headers = self.default_headers or {}
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
@@ -485,6 +485,25 @@ def _identifying_params(self) -> Dict[str, Any]:
             **self._default_params,
         }
 
+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         prompts: List[str],
diff --git a/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py b/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py
index e7d32d2de..197171353 100644
--- a/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py
+++ b/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py
@@ -26,7 +26,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
-CONST_ENDPOINT = "/v1/chat/completions"
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -124,7 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -137,7 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -152,7 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -191,7 +191,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
diff --git a/tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py
index 14cdda3ec..825c49d87 100644
--- a/tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py
+++ b/tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py
@@ -24,7 +24,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
-CONST_ENDPOINT = "/v1/completions"
+CONST_COMPLETION_ROUTE = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -117,7 +117,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -130,7 +130,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -148,7 +148,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
 
@@ -167,7 +167,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": CONST_ENDPOINT}
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
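For reviewers, a minimal sketch of the header-merging behaviour this diff introduces: the user-facing field is now `default_headers` (defaulting to `None`), and each subclass's `_headers()` override injects its own `route` before merging in whatever the base class returns, so a caller-supplied route still wins. The stripped-down classes below are illustrative stand-ins, not the real `BaseOCIModelDeployment` hierarchy; only the constant value and the merge pattern come from the diff.

```python
from typing import Any, Dict, Optional

# Value taken from the diff; the completion LLM classes do the same
# with "/v1/completions" (DEFAULT_INFERENCE_ENDPOINT).
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


class FakeBaseDeployment:
    """Illustrative stand-in for BaseOCIModelDeployment's header handling."""

    def __init__(self, default_headers: Optional[Dict[str, Any]] = None) -> None:
        # After this change, user-supplied headers live in `default_headers`
        # and the field defaults to None instead of {"route": ...}.
        self.default_headers = default_headers

    def _headers(self, is_async: bool = False, body: Optional[dict] = None) -> Dict:
        # The real base class also adds auth/signing headers for async calls;
        # here we only model the user-supplied portion.
        return dict(self.default_headers or {})


class FakeChatDeployment(FakeBaseDeployment):
    def _headers(self, is_async: bool = False, body: Optional[dict] = None) -> Dict:
        # The subclass injects its route default first; the base-class result
        # is unpacked last, so a caller-supplied "route" overrides the default.
        return {
            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
            **super()._headers(is_async=is_async, body=body),
        }


# No headers supplied: the chat route default is used.
assert FakeChatDeployment()._headers()["route"] == DEFAULT_INFERENCE_ENDPOINT_CHAT

# A caller-supplied route takes precedence over the class default.
custom = FakeChatDeployment(default_headers={"route": "/custom/predict"})
assert custom._headers()["route"] == "/custom/predict"
```

This is the same pattern the updated tests exercise with `llm._headers().get("route") == CONST_COMPLETION_ROUTE`, which replaces the old assertions against the removed `headers` field.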