Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ODSC-58449: ads LangChain plugin update #877

Merged
merged 5 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ads/llm/langchain/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain.llms.base import LLM
from langchain.pydantic_v1 import BaseModel, Field, root_validator

from ads import logger
from ads.common.auth import default_signer
from ads.config import COMPARTMENT_OCID

Expand Down Expand Up @@ -95,6 +96,11 @@ def validate_environment( # pylint: disable=no-self-argument
"""Validate that python package exists in environment."""
# Initialize client only if user does not pass in client.
# Users may choose to initialize the OCI client by themselves and pass it into this model.
logger.warning(
f"The ads langchain plugin {cls.__name__} will be deprecated soon. "
"Please refer to https://python.langchain.com/v0.2/docs/integrations/providers/oci/ "
"for the latest support."
)
if not values.get("client"):
auth = values.get("auth", {})
client_kwargs = auth.get("client_kwargs") or {}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ pii = [
"spacy==3.6.1",
"report-creator==1.0.9",
]
llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0", "langchain-core<0.1.51"]
llm = ["langchain-community<0.0.32", "langchain>=0.1.10,<0.1.14", "evaluate>=0.4.0"]
aqua = ["jupyter_server"]

# To reduce backtracking (decrese deps install time) during test/dev env setup reducing number of versions pip is
Expand Down
44 changes: 37 additions & 7 deletions tests/unitary/with_extras/langchain/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@

import os
from copy import deepcopy
from unittest import TestCase, mock, SkipTest
from unittest import SkipTest, TestCase, mock, skipIf

from langchain.llms import Cohere
import langchain_core
from langchain.chains import LLMChain
from langchain.llms import Cohere
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough

from ads.llm.serialize import load, dump
from ads.llm import (
GenerativeAI,
GenerativeAIEmbeddings,
ModelDeploymentTGI,
ModelDeploymentVLLM,
)
from ads.llm.serialize import dump, load


def version_tuple(version):
    """Parse a dotted version string into a tuple of ints for comparison.

    Robust to pre-release/dev suffixes (e.g. ``"0.2.0rc1"`` -> ``(0, 2, 0)``):
    each dot-separated component contributes its leading digits, and parsing
    stops at the first component with none. The original implementation
    raised ``ValueError`` on such strings, which would crash the version
    checks when ``langchain_core.__version__`` is a pre-release.

    Parameters
    ----------
    version : str
        A version string such as ``"0.1.50"``.

    Returns
    -------
    tuple of int
        Numeric components, suitable for ordering comparisons.
    """
    parts = []
    for component in version.split("."):
        # Keep only the leading digits of the component ("0rc1" -> "0").
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            # Component has no leading digits (e.g. "dev0" already consumed,
            # or a bare suffix); stop — later parts carry no ordering info.
            break
        parts.append(int(digits))
    return tuple(parts)


class ChainSerializationTest(TestCase):
Expand Down Expand Up @@ -142,6 +147,10 @@ def test_llm_chain_serialization_with_oci(self):
self.assertEqual(llm_chain.llm.model, "my_model")
self.assertEqual(llm_chain.input_keys, ["subject"])

@skipIf(
version_tuple(langchain_core.__version__) > (0, 1, 50),
"Serialization not supported in this langchain_core version",
)
def test_oci_gen_ai_serialization(self):
"""Tests serialization of OCI Gen AI LLM."""
try:
Expand All @@ -157,6 +166,10 @@ def test_oci_gen_ai_serialization(self):
self.assertEqual(llm.compartment_id, self.COMPARTMENT_ID)
self.assertEqual(llm.client_kwargs, self.GEN_AI_KWARGS)

@skipIf(
version_tuple(langchain_core.__version__) > (0, 1, 50),
"Serialization not supported in this langchain_core version",
)
def test_gen_ai_embeddings_serialization(self):
"""Tests serialization of OCI Gen AI embeddings."""
try:
Expand Down Expand Up @@ -201,10 +214,27 @@ def test_runnable_sequence_serialization(self):
element_3 = kwargs.get("last")
self.assertNotIn("_type", element_3)
self.assertEqual(element_3.get("id"), ["ads", "llm", "ModelDeploymentTGI"])
self.assertEqual(
element_3.get("kwargs"),
{"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"},
)

if version_tuple(langchain_core.__version__) > (0, 1, 50):
self.assertEqual(
element_3.get("kwargs"),
{
"max_tokens": 256,
"temperature": 0.2,
"p": 0.75,
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
"best_of": 1,
"do_sample": True,
"watermark": True,
},
)
else:
self.assertEqual(
element_3.get("kwargs"),
{
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
},
)

chain = load(serialized)
self.assertEqual(len(chain.steps), 3)
Expand Down
Loading