Skip to content

Commit

Permalink
ODSC-58449: ads LangChain plugin update (#877)
Browse files Browse the repository at this point in the history
  • Loading branch information
mingkang111 authored Jun 18, 2024
1 parent 30534f7 commit 3f5ff91
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
6 changes: 6 additions & 0 deletions ads/llm/langchain/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain.llms.base import LLM
from langchain.pydantic_v1 import BaseModel, Field, root_validator

from ads import logger
from ads.common.auth import default_signer
from ads.config import COMPARTMENT_OCID

Expand Down Expand Up @@ -95,6 +96,11 @@ def validate_environment( # pylint: disable=no-self-argument
"""Validate that python package exists in environment."""
# Initialize client only if user does not pass in client.
# Users may choose to initialize the OCI client by themselves and pass it into this model.
logger.warning(
f"The ads langchain plugin {cls.__name__} will be deprecated soon. "
"Please refer to https://python.langchain.com/v0.2/docs/integrations/providers/oci/ "
"for the latest support."
)
if not values.get("client"):
auth = values.get("auth", {})
client_kwargs = auth.get("client_kwargs") or {}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ pii = [
"spacy==3.6.1",
"report-creator==1.0.9",
]
llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0", "langchain-core<0.1.51"]
llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0"]
aqua = ["jupyter_server"]

# To reduce backtracking (decrese deps install time) during test/dev env setup reducing number of versions pip is
Expand Down
44 changes: 37 additions & 7 deletions tests/unitary/with_extras/langchain/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@

import os
from copy import deepcopy
from unittest import TestCase, mock, SkipTest
from unittest import SkipTest, TestCase, mock, skipIf

from langchain.llms import Cohere
import langchain_core
from langchain.chains import LLMChain
from langchain.llms import Cohere
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough

from ads.llm.serialize import load, dump
from ads.llm import (
GenerativeAI,
GenerativeAIEmbeddings,
ModelDeploymentTGI,
ModelDeploymentVLLM,
)
from ads.llm.serialize import dump, load


def version_tuple(version):
return tuple(map(int, version.split(".")))


class ChainSerializationTest(TestCase):
Expand Down Expand Up @@ -142,6 +147,10 @@ def test_llm_chain_serialization_with_oci(self):
self.assertEqual(llm_chain.llm.model, "my_model")
self.assertEqual(llm_chain.input_keys, ["subject"])

@skipIf(
version_tuple(langchain_core.__version__) > (0, 1, 50),
"Serialization not supported in this langchain_core version",
)
def test_oci_gen_ai_serialization(self):
"""Tests serialization of OCI Gen AI LLM."""
try:
Expand All @@ -157,6 +166,10 @@ def test_oci_gen_ai_serialization(self):
self.assertEqual(llm.compartment_id, self.COMPARTMENT_ID)
self.assertEqual(llm.client_kwargs, self.GEN_AI_KWARGS)

@skipIf(
version_tuple(langchain_core.__version__) > (0, 1, 50),
"Serialization not supported in this langchain_core version",
)
def test_gen_ai_embeddings_serialization(self):
"""Tests serialization of OCI Gen AI embeddings."""
try:
Expand Down Expand Up @@ -201,10 +214,27 @@ def test_runnable_sequence_serialization(self):
element_3 = kwargs.get("last")
self.assertNotIn("_type", element_3)
self.assertEqual(element_3.get("id"), ["ads", "llm", "ModelDeploymentTGI"])
self.assertEqual(
element_3.get("kwargs"),
{"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"},
)

if version_tuple(langchain_core.__version__) > (0, 1, 50):
self.assertEqual(
element_3.get("kwargs"),
{
"max_tokens": 256,
"temperature": 0.2,
"p": 0.75,
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
"best_of": 1,
"do_sample": True,
"watermark": True,
},
)
else:
self.assertEqual(
element_3.get("kwargs"),
{
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
},
)

chain = load(serialized)
self.assertEqual(len(chain.steps), 3)
Expand Down

0 comments on commit 3f5ff91

Please sign in to comment.