From 25511ec535d3974afaaa39cc07ce3666add21652 Mon Sep 17 00:00:00 2001 From: Dudi Zimberknopf Date: Wed, 3 Jul 2024 17:59:54 +0300 Subject: [PATCH] add openai support --- falkordb_gemini_kg/kg.py | 11 +- falkordb_gemini_kg/models/openai.py | 109 ++++++++++++++++++ poetry.lock | 143 +++++++++++++++++++++++- pyproject.toml | 1 + tests/{test_kg.py => test_kg_gemini.py} | 2 +- tests/test_kg_openai.py | 100 +++++++++++++++++ 6 files changed, 362 insertions(+), 4 deletions(-) create mode 100644 falkordb_gemini_kg/models/openai.py rename tests/{test_kg.py => test_kg_gemini.py} (98%) create mode 100644 tests/test_kg_openai.py diff --git a/falkordb_gemini_kg/kg.py b/falkordb_gemini_kg/kg.py index e57cb12..5209bd1 100644 --- a/falkordb_gemini_kg/kg.py +++ b/falkordb_gemini_kg/kg.py @@ -5,6 +5,7 @@ from falkordb_gemini_kg.classes.model_config import KnowledgeGraphModelConfig from falkordb_gemini_kg.steps.extract_data_step import ExtractDataStep from falkordb_gemini_kg.steps.graph_query_step import GraphQueryGenerationStep +from falkordb_gemini_kg.fixtures.prompts import GRAPH_QA_SYSTEM, CYPHER_GEN_SYSTEM from falkordb_gemini_kg.steps.qa_step import QAStep from falkordb_gemini_kg.classes.ChatSession import ChatSession @@ -124,7 +125,11 @@ def ask(self, question: str) -> str: >>> ans = kg.ask("List a few movies in which that actored played in", history) """ - cypher_chat_session = self._model_config.cypher_generation.start_chat() + cypher_chat_session = ( + self._model_config.cypher_generation.with_system_instruction( + CYPHER_GEN_SYSTEM.replace("#ONTOLOGY", str(self.ontology.to_json())), + ).start_chat() + ) cypher_step = GraphQueryGenerationStep( ontology=self.ontology, chat_session=cypher_chat_session, @@ -133,7 +138,9 @@ def ask(self, question: str) -> str: (context, cypher) = cypher_step.run(question) - qa_chat_session = self._model_config.qa.start_chat() + qa_chat_session = self._model_config.qa.with_system_instruction( + GRAPH_QA_SYSTEM + ).start_chat() qa_step = QAStep( 
from typing import Any

from .model import *
from openai import OpenAI, completions


class OpenAiGenerativeModel(GenerativeModel):
    """Generative model backed by the OpenAI chat-completions endpoint.

    The SDK client is created lazily on first use, so constructing the model
    performs no I/O and needs no credentials.
    """

    # Lazily created SDK client; remains None until `_get_model()` is called.
    client: OpenAI | None = None

    def __init__(
        self,
        model_name: str,
        generation_config: GenerativeModelConfig | None = None,
        system_instruction: str | None = None,
    ):
        """Store the model settings.

        Args:
            model_name: Model identifier passed as `model` on every request.
            generation_config: Optional sampling/length settings. When None,
                no generation parameters are sent with requests.
            system_instruction: Optional system prompt sent ahead of the
                user message.
        """
        self.model_name = model_name
        self.generation_config = generation_config
        self.system_instruction = system_instruction

    def _get_model(self) -> OpenAI:
        """Return the SDK client, creating it on first use."""
        if self.client is None:
            # OpenAI() is called without arguments, so the SDK resolves
            # credentials and endpoint itself.
            self.client = OpenAI()

        return self.client

    def _generation_kwargs(self) -> dict[str, Any]:
        """Translate `generation_config` into chat-completions keyword arguments.

        Returns:
            The parameters that are actually set; an empty dict when no
            configuration was supplied. `top_k` is deliberately not forwarded
            because the chat-completions endpoint has no such parameter.
        """
        config = self.generation_config
        if config is None:
            return {}

        candidates = {
            "max_tokens": config.max_output_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "stop": config.stop_sequences,
        }
        # Omit unset values so the API applies its own defaults.
        return {key: value for key, value in candidates.items() if value is not None}

    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
        """Set the system prompt on this model and return the model itself.

        NOTE(review): this mutates the instance in place, so every holder of
        this model object sees the new instruction. Chat sessions copy the
        instruction when they are created and are therefore unaffected by
        later calls.
        """
        self.system_instruction = system_instruction
        # Make sure a client exists; the client does not depend on the
        # instruction, so an existing one is reused rather than rebuilt.
        self._get_model()

        return self

    def start_chat(self, args: dict | None = None) -> GenerativeModelChatSession:
        """Open a new chat session bound to this model."""
        return OpenAiChatSession(self, args)

    def ask(self, message: str) -> GenerationResponse:
        """Send a single, history-free prompt and return the parsed reply.

        Args:
            message: The user message.

        Returns:
            The reply text and normalised finish reason.
        """
        messages = []
        if self.system_instruction is not None:
            # Only send a system message when there is one.
            messages.append({"role": "system", "content": self.system_instruction})
        messages.append({"role": "user", "content": message})

        response = self._get_model().chat.completions.create(
            model=self.model_name,
            messages=messages,
            **self._generation_kwargs(),
        )
        return self.parse_generate_content_response(response)

    def parse_generate_content_response(self, response: Any) -> GenerationResponse:
        """Convert a chat-completions response into a `GenerationResponse`.

        Only the first choice is read. The finish reason "stop" maps to
        `FinishReason.STOP`, "length" to `FinishReason.MAX_TOKENS`, and
        anything else to `FinishReason.OTHER`.
        """
        choice = response.choices[0]
        finish_reasons = {
            "stop": FinishReason.STOP,
            "length": FinishReason.MAX_TOKENS,
        }
        return GenerationResponse(
            text=choice.message.content,
            finish_reason=finish_reasons.get(choice.finish_reason, FinishReason.OTHER),
        )


class OpenAiChatSession(GenerativeModelChatSession):
    """Multi-turn conversation that replays its history on every request."""

    def __init__(self, model: OpenAiGenerativeModel, args: dict | None = None):
        """Start a session, seeding the history with the model's system prompt.

        Args:
            model: The model that executes the requests.
            args: Stored on the session; not read by this class.
        """
        self._model = model
        self._args = args
        # Copy the instruction now so later changes on the model do not
        # alter a conversation that is already running.
        self._history = (
            [{"role": "system", "content": self._model.system_instruction}]
            if self._model.system_instruction is not None
            else []
        )

    def send_message(self, message: str) -> GenerationResponse:
        """Send `message` with the accumulated history and record the exchange.

        The history is only extended after the request succeeds, so a failed
        call leaves the session unchanged.

        Args:
            message: The next user message.

        Returns:
            The reply text and normalised finish reason.
        """
        user_turn = {"role": "user", "content": message}
        response = self._model._get_model().chat.completions.create(
            model=self._model.model_name,
            messages=[*self._history, user_turn],
            **self._model._generation_kwargs(),
        )
        content = self._model.parse_generate_content_response(response)
        self._history.append(user_turn)
        # Model replies belong to the "assistant" role; recording them as
        # "system" would turn each answer into an instruction on later turns.
        self._history.append({"role": "assistant", "content": content.text})
        return content
"sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, +] + +[package.dependencies] +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + [[package]] name = "babel" version = "2.15.0" @@ -214,6 +234,17 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "docstring-parser" version = "0.16" @@ -710,6 +741,62 @@ googleapis-common-protos = ">=1.5.5" grpcio = ">=1.62.2" protobuf = ">=4.21.6" +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + +[[package]] +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "idna" version = "3.7" @@ -883,6 +970,29 @@ files = [ {file = "numpy-2.0.0.tar.gz", hash = "sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864"}, ] +[[package]] +name = "openai" +version = "1.35.9" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.35.9-py3-none-any.whl", hash = "sha256:d73d353bcc0bd46b9516e78a0c6fb1cffaaeb92906c7c7b467c4fa088332a150"}, + {file = "openai-1.35.9.tar.gz", hash = "sha256:4f5c1b90526cf48eaedac7b32d11b5c92fa7064b82617ad8f5f3279cd9ef090d"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] 
+datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "packaging" version = "24.1" @@ -1382,6 +1492,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -1652,6 +1773,26 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "tqdm" +version = "4.66.4" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, + {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -1731,4 +1872,4 @@ xai = ["tensorflow (>=2.3.0,<3.0.0dev)"] [metadata] lock-version = "2.0" python-versions = "^3.11.4" -content-hash = "ce0d4be4a2614e6cda568a2af1f0068532528e20f01163963cccb7f0f63851b4" +content-hash = "f4afb02d8a5f1042e2c2c170e4de0e4a25303a52e60b8396ca7dd87be1588a48" diff --git a/pyproject.toml b/pyproject.toml index 025975c..6b83308 100644 --- 
from dotenv import load_dotenv

# Environment variables are loaded before the remaining imports run —
# presumably because some of them read configuration at import time.
load_dotenv()
from falkordb_gemini_kg.classes.ontology import Ontology
from falkordb_gemini_kg.classes.node import Node
from falkordb_gemini_kg.classes.edge import Edge
from falkordb_gemini_kg.classes.attribute import Attribute, AttributeType
import unittest
from falkordb_gemini_kg.classes.source import Source
from falkordb_gemini_kg.models.openai import OpenAiGenerativeModel
from falkordb_gemini_kg import KnowledgeGraph, KnowledgeGraphModelConfig
import vertexai
import os
import logging
from falkordb import FalkorDB

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def _string_attribute(name: str, *, unique: bool, required: bool) -> Attribute:
    """Build a STRING-typed ontology attribute."""
    return Attribute(
        name=name,
        attr_type=AttributeType.STRING,
        unique=unique,
        required=required,
    )


class TestKGOpenAI(unittest.TestCase):
    """
    Test Knowledge Graph with the OpenAI-backed generative model
    """

    @classmethod
    def setUpClass(cls):
        """Build the Actor/Movie ontology and the "IMDB" knowledge graph once."""
        cls.ontology = Ontology([], [])

        cls.ontology.add_node(
            Node(
                label="Actor",
                attributes=[_string_attribute("name", unique=True, required=True)],
            )
        )
        cls.ontology.add_node(
            Node(
                label="Movie",
                attributes=[_string_attribute("title", unique=True, required=True)],
            )
        )
        cls.ontology.add_edge(
            Edge(
                label="ACTED_IN",
                source="Actor",
                target="Movie",
                attributes=[_string_attribute("role", unique=False, required=False)],
            )
        )

        model = OpenAiGenerativeModel(model_name="gpt-3.5-turbo-0125")
        cls.kg = KnowledgeGraph(
            name="IMDB",
            ontology=cls.ontology,
            model_config=KnowledgeGraphModelConfig.with_model(model),
        )

    def test_kg_creation(self):
        """Process the sample source and check a question can be answered from it."""
        file_path = "tests/data/madoff.txt"

        sources = [Source(file_path)]

        self.kg.process_sources(sources)

        answer = self.kg.ask("List a few actors")

        logger.info("Answer: %s", answer)

        # assertIn is not stripped under `python -O`, unlike a bare assert,
        # and reports both operands on failure.
        self.assertIn("Joseph Scotto", answer, "Joseph Scotto not found in answer")

    def test_kg_delete(self):
        """Delete the graph and confirm the database no longer lists it."""
        self.kg.delete()

        db = FalkorDB()
        graphs = db.list_graphs()
        self.assertNotIn("IMDB", graphs)