From 3f5ff91bebbb423e66532bbf5e78ff461ed510cc Mon Sep 17 00:00:00 2001 From: MING KANG Date: Tue, 18 Jun 2024 12:18:53 -0400 Subject: [PATCH] ODSC-58449: ads LangChain plugin update (#877) --- ads/llm/langchain/plugins/base.py | 6 +++ pyproject.toml | 2 +- .../langchain/test_serialization.py | 44 ++++++++++++++++--- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/ads/llm/langchain/plugins/base.py b/ads/llm/langchain/plugins/base.py index 37788f751..d9f260832 100644 --- a/ads/llm/langchain/plugins/base.py +++ b/ads/llm/langchain/plugins/base.py @@ -8,6 +8,7 @@ from langchain.llms.base import LLM from langchain.pydantic_v1 import BaseModel, Field, root_validator +from ads import logger from ads.common.auth import default_signer from ads.config import COMPARTMENT_OCID @@ -95,6 +96,11 @@ def validate_environment( # pylint: disable=no-self-argument """Validate that python package exists in environment.""" # Initialize client only if user does not pass in client. # Users may choose to initialize the OCI client by themselves and pass it into this model. + logger.warning( + f"The ads langchain plugin {cls.__name__} will be deprecated soon. " + "Please refer to https://python.langchain.com/v0.2/docs/integrations/providers/oci/ " + "for the latest support." + ) if not values.get("client"): auth = values.get("auth", {}) client_kwargs = auth.get("client_kwargs") or {} diff --git a/pyproject.toml b/pyproject.toml index 4e6c5fe48..76b411785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,7 +175,7 @@ pii = [ "spacy==3.6.1", "report-creator==1.0.9", ] -llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0", "langchain-core<0.1.51"] +llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0"] aqua = ["jupyter_server"] # To reduce backtracking (decrese deps install time) during test/dev env setup reducing number of versions pip is diff --git a/tests/unitary/with_extras/langchain/test_serialization.py b/tests/unitary/with_extras/langchain/test_serialization.py index 7a0c582a6..831be24b2 100644 --- a/tests/unitary/with_extras/langchain/test_serialization.py +++ b/tests/unitary/with_extras/langchain/test_serialization.py @@ -7,20 +7,25 @@ import os from copy import deepcopy -from unittest import TestCase, mock, SkipTest +from unittest import SkipTest, TestCase, mock, skipIf -from langchain.llms import Cohere +import langchain_core from langchain.chains import LLMChain +from langchain.llms import Cohere from langchain.prompts import PromptTemplate from langchain.schema.runnable import RunnableParallel, RunnablePassthrough -from ads.llm.serialize import load, dump from ads.llm import ( GenerativeAI, GenerativeAIEmbeddings, ModelDeploymentTGI, ModelDeploymentVLLM, ) +from ads.llm.serialize import dump, load + + +def version_tuple(version): + return tuple(map(int, version.split("."))) class ChainSerializationTest(TestCase): @@ -142,6 +147,10 @@ def test_llm_chain_serialization_with_oci(self): self.assertEqual(llm_chain.llm.model, "my_model") self.assertEqual(llm_chain.input_keys, ["subject"]) + @skipIf( + version_tuple(langchain_core.__version__) > (0, 1, 50), + "Serialization not supported in this langchain_core version", + ) def test_oci_gen_ai_serialization(self): """Tests serialization of OCI Gen AI LLM.""" try: @@ -157,6 +166,10 @@ def test_oci_gen_ai_serialization(self): self.assertEqual(llm.compartment_id, self.COMPARTMENT_ID) self.assertEqual(llm.client_kwargs, self.GEN_AI_KWARGS) + @skipIf( + version_tuple(langchain_core.__version__) > (0, 1, 50), + "Serialization not supported in this langchain_core version", + ) def test_gen_ai_embeddings_serialization(self): """Tests serialization of OCI Gen AI embeddings.""" try: @@ -201,10 +214,27 @@ def test_runnable_sequence_serialization(self): element_3 = kwargs.get("last") self.assertNotIn("_type", element_3) self.assertEqual(element_3.get("id"), ["ads", "llm", "ModelDeploymentTGI"]) - self.assertEqual( - element_3.get("kwargs"), - {"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"}, - ) + + if version_tuple(langchain_core.__version__) > (0, 1, 50): + self.assertEqual( + element_3.get("kwargs"), + { + "max_tokens": 256, + "temperature": 0.2, + "p": 0.75, + "endpoint": "https://modeldeployment.customer-oci.com/ocid/predict", + "best_of": 1, + "do_sample": True, + "watermark": True, + }, + ) + else: + self.assertEqual( + element_3.get("kwargs"), + { + "endpoint": "https://modeldeployment.customer-oci.com/ocid/predict", + }, + ) chain = load(serialized) self.assertEqual(len(chain.steps), 3)