add openai support
dudizimber committed Jul 3, 2024
1 parent 33024f9 commit 25511ec
Showing 6 changed files with 362 additions and 4 deletions.
11 changes: 9 additions & 2 deletions falkordb_gemini_kg/kg.py
@@ -5,6 +5,7 @@
from falkordb_gemini_kg.classes.model_config import KnowledgeGraphModelConfig
from falkordb_gemini_kg.steps.extract_data_step import ExtractDataStep
from falkordb_gemini_kg.steps.graph_query_step import GraphQueryGenerationStep
+from falkordb_gemini_kg.fixtures.prompts import GRAPH_QA_SYSTEM, CYPHER_GEN_SYSTEM
from falkordb_gemini_kg.steps.qa_step import QAStep
from falkordb_gemini_kg.classes.ChatSession import ChatSession

@@ -124,7 +125,11 @@ def ask(self, question: str) -> str:
        >>> ans = kg.ask("List a few movies that actor played in", history)
        """

-        cypher_chat_session = self._model_config.cypher_generation.start_chat()
+        cypher_chat_session = (
+            self._model_config.cypher_generation.with_system_instruction(
+                CYPHER_GEN_SYSTEM.replace("#ONTOLOGY", str(self.ontology.to_json())),
+            ).start_chat()
+        )
        cypher_step = GraphQueryGenerationStep(
            ontology=self.ontology,
            chat_session=cypher_chat_session,
@@ -133,7 +138,9 @@ def ask(self, question: str) -> str:

        (context, cypher) = cypher_step.run(question)

-        qa_chat_session = self._model_config.qa.start_chat()
+        qa_chat_session = self._model_config.qa.with_system_instruction(
+            GRAPH_QA_SYSTEM
+        ).start_chat()
        qa_step = QAStep(
            chat_session=qa_chat_session,
        )
109 changes: 109 additions & 0 deletions falkordb_gemini_kg/models/openai.py
@@ -0,0 +1,109 @@
from typing import Any

from .model import *
from openai import OpenAI


class OpenAiGenerativeModel(GenerativeModel):

    client: OpenAI = None

    def __init__(
        self,
        model_name: str,
        generation_config: GenerativeModelConfig | None = None,
        system_instruction: str | None = None,
    ):
        self.model_name = model_name
        self.generation_config = generation_config
        self.system_instruction = system_instruction

    def _get_model(self) -> OpenAI:
        # Create the OpenAI client lazily on first use.
        if self.client is None:
            self.client = OpenAI()

        return self.client

    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
        self.system_instruction = system_instruction
        # The instruction is attached to each request, so the client can be reused.
        self._get_model()

        return self

    def start_chat(self, args: dict | None = None) -> GenerativeModelChatSession:
        return OpenAiChatSession(self, args)

    def ask(self, message: str) -> GenerationResponse:
        client = self._get_model()  # ensure a client exists even if with_system_instruction() was never called
        config = self.generation_config
        messages = []
        if self.system_instruction is not None:
            messages.append({"role": "system", "content": self.system_instruction})
        messages.append({"role": "user", "content": message})
        # The chat completions API accepts no top_k parameter, so it is not forwarded.
        response = client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=config.max_output_tokens if config is not None else None,
            temperature=config.temperature if config is not None else None,
            top_p=config.top_p if config is not None else None,
            stop=config.stop_sequences if config is not None else None,
        )
        return self.parse_generate_content_response(response)

    def parse_generate_content_response(self, response: Any) -> GenerationResponse:
        finish_reason = response.choices[0].finish_reason
        return GenerationResponse(
            text=response.choices[0].message.content,
            finish_reason=(
                FinishReason.STOP
                if finish_reason == "stop"
                else FinishReason.MAX_TOKENS
                if finish_reason == "length"
                else FinishReason.OTHER
            ),
        )


class OpenAiChatSession(GenerativeModelChatSession):

    def __init__(self, model: OpenAiGenerativeModel, args: dict | None = None):
        self._model = model
        self._args = args
        # Each session owns its history; seed it with the system instruction, if set.
        self._history = (
            [{"role": "system", "content": self._model.system_instruction}]
            if self._model.system_instruction is not None
            else []
        )

    def send_message(self, message: str) -> GenerationResponse:
        prompt = list(self._history)
        prompt.append({"role": "user", "content": message})
        config = self._model.generation_config
        response = self._model._get_model().chat.completions.create(
            model=self._model.model_name,
            messages=prompt,
            max_tokens=config.max_output_tokens if config is not None else None,
            temperature=config.temperature if config is not None else None,
            top_p=config.top_p if config is not None else None,
            stop=config.stop_sequences if config is not None else None,
        )
        content = self._model.parse_generate_content_response(response)
        self._history.append({"role": "user", "content": message})
        # Record the model's reply under the "assistant" role so later turns present
        # the conversation the way the API expects.
        self._history.append({"role": "assistant", "content": content.text})
        return content
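
For reference, a minimal usage sketch of the new class (not part of the commit; it assumes OPENAI_API_KEY is set in the environment and that the model name is valid for your account):

from falkordb_gemini_kg.models.openai import OpenAiGenerativeModel

# Instantiate without a generation config to use the API defaults.
model = OpenAiGenerativeModel("gpt-4o")
chat = model.with_system_instruction("You are a helpful assistant.").start_chat()
answer = chat.send_message("Hello!")
print(answer.text, answer.finish_reason)
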
143 changes: 142 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@
backoff = "^2.2.1"
python-abc = "^0.2.0"
ratelimit = "^2.2.1"
python-dotenv = "^1.0.1"
+openai = "^1.35.9"

[tool.poetry.group.test.dependencies]
pytest = "^8.2.1"
2 changes: 1 addition & 1 deletion tests/test_kg.py → tests/test_kg_gemini.py
@@ -20,7 +20,7 @@
vertexai.init(project=os.getenv("PROJECT_ID"), location=os.getenv("REGION"))


-class TestKG(unittest.TestCase):
+class TestKGGemini(unittest.TestCase):
"""
Test Knowledge Graph
"""