diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml
new file mode 100644
index 0000000..5dc182d
--- /dev/null
+++ b/.github/workflows/spellcheck.yml
@@ -0,0 +1,17 @@
+name: spellcheck
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+jobs:
+  check-spelling:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Check Spelling
+        uses: rojopolis/spellcheck-github-actions@0.33.1
+        with:
+          config_path: .spellcheck.yml
+          task_name: Markdown
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..39bc990
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,58 @@
+name: Run tests
+
+on:
+  push:
+  pull_request:
+
+jobs:
+  test:
+    permissions:
+      contents: "read"
+      id-token: "write"
+    runs-on: ubuntu-latest
+    services:
+      falkordb:
+        image: falkordb/falkordb:latest
+        ports:
+          - 6379:6379
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.x # Update with desired Python version
+
+      - name: Cache Poetry virtualenv
+        id: cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.poetry/virtualenvs
+          key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
+          restore-keys: |
+            ${{ runner.os }}-poetry-
+
+      - id: "auth"
+        uses: "google-github-actions/auth@v2"
+        with:
+          credentials_json: ${{ secrets.GCP_SA_KEY }}
+
+      - name: "Set up Cloud SDK"
+        uses: "google-github-actions/setup-gcloud@v2"
+        with:
+          version: ">= 363.0.0"
+
+      - name: Install Poetry
+        if: steps.cache.outputs.cache-hit != 'true'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+      - name: Install dependencies
+        run: poetry install
+
+      - name: Run tests
+        env:
+          PROJECT_ID: ${{ secrets.PROJECT_ID }}
+          REGION: ${{ vars.REGION }}
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+        run: poetry run pytest
diff --git a/config.yaml b/config.yaml
deleted file mode 100644
index a225d69..0000000
--- a/config.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-ontology:
-  model_name: "gemini-1.5-flash-001"
-  max_output_tokens: 8192
-  max_input_characters: 500000
-  temperature: 1.5
-  top_p: 0.1
-  min_urls_count: 5
-  output_file: output/ontology.json
-  cache_file: cache/scrape_cache.json
-  max_workers: 16
-
-extract_data:
-  model_name: "gemini-1.5-flash-001"
-  max_output_tokens: 8192
-  max_input_characters: 500000
-  temperature: 1.5
-  top_p: 0.1
-  ontology_file: output/ontology.json
-  cache_file: cache/scrape_cache.json
-  max_workers: 16
-
-query_graph:
-  ontology_file: output/ontology.json
-  cypher:
-    model_name: "gemini-1.5-flash-001"
-    max_output_tokens: 8192
-    max_input_characters: 500000
-    temperature: 1.5
-    top_p: 0.1
-  qa:
-    model_name: "gemini-1.5-flash-001"
-    max_output_tokens: 8192
-    max_input_characters: 500000
-    temperature: 1.5
-    top_p: 0.1
diff --git a/falkordb_gemini_kg/__init__.py b/falkordb_gemini_kg/__init__.py
index 3e99b55..426e6a8 100644
--- a/falkordb_gemini_kg/__init__.py
+++ b/falkordb_gemini_kg/__init__.py
@@ -1,7 +1,7 @@
 from .classes.source import Source
 from .classes.ontology import Ontology
 from .kg import KnowledgeGraph
-from .classes.model_config import KnowledgeGraphModelConfig, StepModelConfig
+from .classes.model_config import KnowledgeGraphModelConfig
 from .steps.create_ontology_step import CreateOntologyStep
 
 # Setup Null handler
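Note on the test workflow above: the `falkordb` service container publishes port 6379 on the runner, which is what lets the integration tests further down talk to a live graph with default client settings. A minimal sketch of that connection (the graph name here is illustrative, not from this diff):

```python
# Sketch: reaching the FalkorDB service container declared in test.yml.
from falkordb import FalkorDB

db = FalkorDB(host="localhost", port=6379)  # matches the service port mapping
graph = db.select_graph("smoke_test")
graph.query("RETURN 1")  # fails fast if the service container is not up
```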
diff --git a/falkordb_gemini_kg/classes/__init__.py b/falkordb_gemini_kg/classes/__init__.py
index 40e00ee..7583b6a 100644
--- a/falkordb_gemini_kg/classes/__init__.py
+++ b/falkordb_gemini_kg/classes/__init__.py
@@ -2,8 +2,10 @@
 from .source import Source
 from .node import Node
 from .edge import Edge
+from .attribute import Attribute, AttributeType
 
 # Setup Null handler
 import logging
+
 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
diff --git a/falkordb_gemini_kg/classes/attribute.py b/falkordb_gemini_kg/classes/attribute.py
index 0e730e2..5b75059 100644
--- a/falkordb_gemini_kg/classes/attribute.py
+++ b/falkordb_gemini_kg/classes/attribute.py
@@ -5,7 +5,7 @@
 logger = logging.getLogger(__name__)
 
 
-class _AttributeType:
+class AttributeType:
     STRING = "string"
     NUMBER = "number"
     BOOLEAN = "boolean"
@@ -13,15 +13,15 @@ class AttributeType:
     @staticmethod
     def fromString(txt: str):
         if txt.isdigit():
-            return _AttributeType.NUMBER
+            return AttributeType.NUMBER
         elif txt.lower() in ["true", "false"]:
-            return _AttributeType.BOOLEAN
-        return _AttributeType.STRING
+            return AttributeType.BOOLEAN
+        return AttributeType.STRING
 
 
 class Attribute:
     def __init__(
-        self, name: str, attr_type: _AttributeType, unique: bool, required: bool = False
+        self, name: str, attr_type: AttributeType, unique: bool, required: bool = False
     ):
         self.name = name
         self.type = attr_type
@@ -32,9 +32,9 @@ def __init__(
     def from_json(txt: str):
         txt = txt if isinstance(txt, dict) else json.loads(txt)
         if txt["type"] not in [
-            _AttributeType.STRING,
-            _AttributeType.NUMBER,
-            _AttributeType.BOOLEAN,
+            AttributeType.STRING,
+            AttributeType.NUMBER,
+            AttributeType.BOOLEAN,
         ]:
             raise Exception(f"Invalid attribute type: {txt['type']}")
         return Attribute(
@@ -52,9 +52,9 @@ def from_string(txt: str):
         required = "*" in txt
 
         if attr_type not in [
-            _AttributeType.STRING,
-            _AttributeType.NUMBER,
-            _AttributeType.BOOLEAN,
+            AttributeType.STRING,
+            AttributeType.NUMBER,
+            AttributeType.BOOLEAN,
        ]:
             raise Exception(f"Invalid attribute type: {attr_type}")
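With `_AttributeType` renamed to the public `AttributeType` and re-exported from `falkordb_gemini_kg.classes`, callers can construct typed attributes directly. A small usage sketch based on the signatures above:

```python
from falkordb_gemini_kg.classes.attribute import Attribute, AttributeType

# Explicit construction, as the new test suite below does:
name_attr = Attribute("name", AttributeType.STRING, unique=True, required=True)

# Type inference from raw property values:
assert AttributeType.fromString("42") == AttributeType.NUMBER
assert AttributeType.fromString("true") == AttributeType.BOOLEAN
assert AttributeType.fromString("Joseph Scotto") == AttributeType.STRING
```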
diff --git a/falkordb_gemini_kg/classes/edge.py b/falkordb_gemini_kg/classes/edge.py
index 5a00044..d92f2e1 100644
--- a/falkordb_gemini_kg/classes/edge.py
+++ b/falkordb_gemini_kg/classes/edge.py
@@ -1,7 +1,7 @@
 import json
 import re
 import logging
-from .attribute import Attribute, _AttributeType
+from .attribute import Attribute, AttributeType
 from falkordb import Node as GraphNode, Edge as GraphEdge
 from falkordb_gemini_kg.fixtures.regex import *
 
@@ -58,7 +58,7 @@ def from_graph(edge: GraphEdge, nodes: list[GraphNode]):
             [
                 Attribute(
                     attr,
-                    _AttributeType.fromString(edge.properties),
+                    AttributeType.fromString(edge.properties[attr]),
                     "!" in edge.properties[attr],
                     "*" in edge.properties[attr],
                 )
diff --git a/falkordb_gemini_kg/classes/model_config.py b/falkordb_gemini_kg/classes/model_config.py
index d1e3af0..d567971 100644
--- a/falkordb_gemini_kg/classes/model_config.py
+++ b/falkordb_gemini_kg/classes/model_config.py
@@ -1,66 +1,22 @@
-from vertexai.generative_models import GenerationConfig
-
-
-class StepModelGenerationConfig:
-    def __init__(
-        self,
-        temperature: float,
-        top_p: float,
-        top_k: int,
-        candidate_count: int,
-        max_output_tokens: int,
-        stop_sequences: list[str],
-    ):
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
-        self.candidate_count = candidate_count
-        self.max_output_tokens = max_output_tokens
-        self.stop_sequences = stop_sequences
-
-    def to_generation_config(self) -> GenerationConfig:
-        return GenerationConfig(
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            candidate_count=self.candidate_count,
-            max_output_tokens=self.max_output_tokens,
-            stop_sequences=self.stop_sequences,
-        )
-
-
-class StepModelConfig:
-
-    def __init__(
-        self, model: str, generation_config: StepModelGenerationConfig | None = None
-    ):
-        self.model = model
-        self.generation_config = generation_config
+from falkordb_gemini_kg.models import GenerativeModel
 
 
 class KnowledgeGraphModelConfig:
     def __init__(
         self,
-        extract_data: StepModelConfig | None = None,
-        cypher_generation: StepModelConfig | None = None,
-        qa: StepModelConfig | None = None,
+        extract_data: GenerativeModel,
+        cypher_generation: GenerativeModel,
+        qa: GenerativeModel,
     ):
         self.extract_data = extract_data
         self.cypher_generation = cypher_generation
         self.qa = qa
 
     @staticmethod
-    def from_dict(d: dict):
-        model = d.get("model")
-        generation_config = d.get("generation_config")
-        extract_data = StepModelConfig(model=model, generation_config=generation_config)
-        cypher_generation = StepModelConfig(
-            model=model, generation_config=generation_config
-        )
-        qa = StepModelConfig(model=model, generation_config=generation_config)
+    def with_model(model: GenerativeModel):
         return KnowledgeGraphModelConfig(
-            extract_data=extract_data,
-            cypher_generation=cypher_generation,
-            qa=qa,
+            extract_data=model,
+            cypher_generation=model,
+            qa=model,
         )
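The `model_config.py` rewrite replaces per-step `StepModelConfig` objects with `GenerativeModel` instances; `with_model` is the shorthand for sharing one model across extraction, Cypher generation, and QA. A sketch of both construction paths:

```python
from falkordb_gemini_kg import KnowledgeGraphModelConfig
from falkordb_gemini_kg.models.gemini import GeminiGenerativeModel

# One model instance reused by all three steps:
model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
config = KnowledgeGraphModelConfig.with_model(model)

# Or mix models per step via the constructor (second model is illustrative):
config = KnowledgeGraphModelConfig(
    extract_data=GeminiGenerativeModel(model_name="gemini-1.5-flash-001"),
    cypher_generation=model,
    qa=model,
)
```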
diff --git a/falkordb_gemini_kg/classes/node.py b/falkordb_gemini_kg/classes/node.py
index 737211d..46d383e 100644
--- a/falkordb_gemini_kg/classes/node.py
+++ b/falkordb_gemini_kg/classes/node.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from .attribute import Attribute, _AttributeType
+from .attribute import Attribute, AttributeType
 from falkordb import Node as GraphNode
 
 logger = logging.getLogger(__name__)
@@ -18,7 +18,7 @@ def from_graph(node: GraphNode):
             [
                 Attribute(
                     attr,
-                    _AttributeType.fromString(node.properties[attr]),
+                    AttributeType.fromString(node.properties[attr]),
                     "!" in node.properties[attr],
                 )
                 for attr in node.properties
diff --git a/falkordb_gemini_kg/classes/ontology.py b/falkordb_gemini_kg/classes/ontology.py
index a90e141..1935e8b 100644
--- a/falkordb_gemini_kg/classes/ontology.py
+++ b/falkordb_gemini_kg/classes/ontology.py
@@ -1,7 +1,7 @@
 import json
 from falkordb import Graph
 from falkordb_gemini_kg.classes.source import AbstractSource
-from falkordb_gemini_kg.classes.model_config import StepModelConfig
+from falkordb_gemini_kg.models import GenerativeModel
 import falkordb_gemini_kg
 import logging
 from .edge import Edge
@@ -19,12 +19,12 @@ def __init__(self, nodes: list[Node] = [], edges: list[Edge] = []):
     def from_sources(
         sources: list[AbstractSource],
         boundaries: str,
-        model_config: StepModelConfig,
+        model: GenerativeModel,
     ) -> "Ontology":
         step = falkordb_gemini_kg.CreateOntologyStep(
             sources=sources,
             ontology=Ontology(),
-            model_config=model_config,
+            model=model,
         )
 
         return step.run(boundaries=boundaries)
diff --git a/falkordb_gemini_kg/helpers.py b/falkordb_gemini_kg/helpers.py
index ac9d31c..8a172ae 100644
--- a/falkordb_gemini_kg/helpers.py
+++ b/falkordb_gemini_kg/helpers.py
@@ -165,9 +165,6 @@ def validate_cypher_edge_directions(cypher: str, ontology: Ontology):
             if ontology_edge is None:
                 errors.append(f"Edge {edge_label} not found in ontology")
 
-            logger.debug(
-                f"ontology_edge: {ontology_edge}"
-            )
             if (
                 not ontology_edge.source.label == source_label
                 or not ontology_edge.target.label == target_label
diff --git a/falkordb_gemini_kg/kg.py b/falkordb_gemini_kg/kg.py
index ff56742..e57cb12 100644
--- a/falkordb_gemini_kg/kg.py
+++ b/falkordb_gemini_kg/kg.py
@@ -104,7 +104,7 @@ def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None
         step = ExtractDataStep(
             sources=list(sources),
             ontology=self.ontology,
-            model_config=self._model_config.extract_data,
+            model=self._model_config.extract_data,
             graph=self.graph,
         )
 
@@ -124,16 +124,18 @@ def ask(self, question: str) -> str:
             >>> ans = kg.ask("List a few movies that actor played in", history)
         """
 
+        cypher_chat_session = self._model_config.cypher_generation.start_chat()
         cypher_step = GraphQueryGenerationStep(
             ontology=self.ontology,
-            model_config=self._model_config.cypher_generation,
+            chat_session=cypher_chat_session,
             graph=self.graph,
         )
 
         (context, cypher) = cypher_step.run(question)
 
+        qa_chat_session = self._model_config.qa.start_chat()
         qa_step = QAStep(
-            model_config=self._model_config.qa,
+            chat_session=qa_chat_session,
         )
 
         answer = qa_step.run(question, cypher, context)
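`kg.ask()` now builds a fresh chat session per call from the configured models rather than instantiating Vertex AI models inside the steps. A condensed usage sketch mirroring `tests/test_kg.py` below (the ontology is left empty here for brevity; a real run would populate it first):

```python
from falkordb_gemini_kg import KnowledgeGraph, KnowledgeGraphModelConfig
from falkordb_gemini_kg.classes.ontology import Ontology
from falkordb_gemini_kg.classes.source import Source
from falkordb_gemini_kg.models.gemini import GeminiGenerativeModel

model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
kg = KnowledgeGraph(
    name="IMDB",
    ontology=Ontology([], []),  # a real ontology would be built first
    model_config=KnowledgeGraphModelConfig.with_model(model),
)
kg.process_sources([Source("tests/data/madoff.txt")])
print(kg.ask("List a few actors"))  # each ask() starts fresh chat sessions
```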
diff --git a/falkordb_gemini_kg/models/__init__.py b/falkordb_gemini_kg/models/__init__.py
new file mode 100644
index 0000000..87cb736
--- /dev/null
+++ b/falkordb_gemini_kg/models/__init__.py
@@ -0,0 +1 @@
+from .model import *
diff --git a/falkordb_gemini_kg/models/gemini.py b/falkordb_gemini_kg/models/gemini.py
new file mode 100644
index 0000000..08e6dee
--- /dev/null
+++ b/falkordb_gemini_kg/models/gemini.py
@@ -0,0 +1,89 @@
+from .model import *
+from vertexai.generative_models import (
+    GenerativeModel as VertexAiGenerativeModel,
+    GenerationConfig as VertexAiGenerationConfig,
+    GenerationResponse as VertexAiGenerationResponse,
+    FinishReason as VertexAiFinishReason,
+)
+
+
+class GeminiGenerativeModel(GenerativeModel):
+
+    _model: VertexAiGenerativeModel = None
+
+    def __init__(
+        self,
+        model_name: str,
+        generation_config: GenerativeModelConfig | None = None,
+        system_instruction: str | None = None,
+    ):
+        self._model_name = model_name
+        self._generation_config = generation_config
+        self._system_instruction = system_instruction
+
+    def _get_model(self) -> VertexAiGenerativeModel:
+        if self._model is None:
+            self._model = VertexAiGenerativeModel(
+                self._model_name,
+                generation_config=(
+                    VertexAiGenerationConfig(
+                        temperature=self._generation_config.temperature,
+                        top_p=self._generation_config.top_p,
+                        top_k=self._generation_config.top_k,
+                        max_output_tokens=self._generation_config.max_output_tokens,
+                        stop_sequences=self._generation_config.stop_sequences,
+                    )
+                    if self._generation_config is not None
+                    else None
+                ),
+                system_instruction=self._system_instruction,
+            )
+
+        return self._model
+
+    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
+        self._system_instruction = system_instruction
+        self._model = None
+        self._get_model()
+
+        return self
+
+    def start_chat(self, args: dict | None = None) -> GenerativeModelChatSession:
+        return GeminiChatSession(self, args)
+
+    def ask(self, message: str) -> GenerationResponse:
+        response = self._get_model().generate_content(message)
+        return self.parse_generate_content_response(response)
+
+    def parse_generate_content_response(
+        self, response: VertexAiGenerationResponse
+    ) -> GenerationResponse:
+        return GenerationResponse(
+            text=response.text,
+            finish_reason=(
+                FinishReason.MAX_TOKENS
+                if response.candidates[0].finish_reason
+                == VertexAiFinishReason.MAX_TOKENS
+                else (
+                    FinishReason.STOP
+                    if response.candidates[0].finish_reason == VertexAiFinishReason.STOP
+                    else FinishReason.OTHER
+                )
+            ),
+        )
+
+
+class GeminiChatSession(GenerativeModelChatSession):
+
+    def __init__(self, model: GeminiGenerativeModel, args: dict | None = None):
+        self._model = model
+        self._chat_session = self._model._get_model().start_chat(
+            history=args.get("history", []) if args is not None else [],
+            response_validation=(
+                args.get("response_validation", True) if args is not None else True
+            ),
+        )
+
+    def send_message(self, message: str) -> GenerationResponse:
+        response = self._chat_session.send_message(message)
+        return self._model.parse_generate_content_response(response)
diff --git a/falkordb_gemini_kg/models/model.py b/falkordb_gemini_kg/models/model.py
new file mode 100644
index 0000000..847a4da
--- /dev/null
+++ b/falkordb_gemini_kg/models/model.py
@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+
+
+class FinishReason:
+    MAX_TOKENS = "MAX_TOKENS"
+    STOP = "STOP"
+    OTHER = "OTHER"
+
+
+class GenerativeModelConfig:
+    def __init__(
+        self,
+        temperature: float,
+        top_p: float,
+        top_k: int,
+        max_output_tokens: int,
+        stop_sequences: list[str],
+    ):
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.max_output_tokens = max_output_tokens
+        self.stop_sequences = stop_sequences
+
+
+class GenerationResponse:
+
+    def __init__(self, text: str, finish_reason: FinishReason):
+        self.text = text
+        self.finish_reason = finish_reason
+
+
+class GenerativeModelChatSession(ABC):
+
+    @abstractmethod
+    def __init__(self, model: "GenerativeModel"):
+        self.model = model
+
+    @abstractmethod
+    def send_message(self, message: str) -> GenerationResponse:
+        pass
+
+
+class GenerativeModel(ABC):
+
+    @abstractmethod
+    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
+        pass
+
+    @abstractmethod
+    def start_chat(self, args: dict | None) -> GenerativeModelChatSession:
+        pass
+
+    @abstractmethod
+    def ask(self, message: str) -> GenerationResponse:
+        pass
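The new `models` package is a provider-agnostic seam: `GeminiGenerativeModel` is its first implementation, and anything satisfying `GenerativeModel`/`GenerativeModelChatSession` can plug into the steps. A hypothetical stub implementation (not part of this diff) that would be handy for offline tests:

```python
from falkordb_gemini_kg.models.model import (
    FinishReason,
    GenerationResponse,
    GenerativeModel,
    GenerativeModelChatSession,
)


class EchoModel(GenerativeModel):
    """Hypothetical stub: returns the prompt unchanged, for offline tests."""

    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
        self._system_instruction = system_instruction
        return self

    def start_chat(self, args: dict | None = None) -> GenerativeModelChatSession:
        return EchoChatSession(self)

    def ask(self, message: str) -> GenerationResponse:
        return GenerationResponse(text=message, finish_reason=FinishReason.STOP)


class EchoChatSession(GenerativeModelChatSession):
    def __init__(self, model: EchoModel):
        self.model = model

    def send_message(self, message: str) -> GenerationResponse:
        return self.model.ask(message)
```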
diff --git a/falkordb_gemini_kg/steps/create_ontology_step.py b/falkordb_gemini_kg/steps/create_ontology_step.py
index 4e9ecdd..8c0b62e 100644
--- a/falkordb_gemini_kg/steps/create_ontology_step.py
+++ b/falkordb_gemini_kg/steps/create_ontology_step.py
@@ -3,14 +3,6 @@
 from falkordb_gemini_kg.classes.Document import Document
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from falkordb_gemini_kg.classes.ontology import Ontology
-from falkordb_gemini_kg.classes.model_config import StepModelConfig
-from vertexai.generative_models import (
-    GenerativeModel,
-    ChatSession,
-    ResponseValidationError,
-    GenerationResponse,
-    FinishReason,
-)
 from falkordb_gemini_kg.fixtures.prompts import (
     CREATE_ONTOLOGY_SYSTEM,
     CREATE_ONTOLOGY_PROMPT,
@@ -20,6 +12,13 @@
 from falkordb_gemini_kg.helpers import extract_json
 from ratelimit import limits, sleep_and_retry
 import time
+from falkordb_gemini_kg.models import (
+    GenerativeModel,
+    GenerativeModelChatSession,
+    GenerativeModelConfig,
+    GenerationResponse,
+    FinishReason,
+)
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -34,7 +33,7 @@ def __init__(
         self,
         sources: list[AbstractSource],
         ontology: Ontology,
-        model_config: StepModelConfig,
+        model: GenerativeModel,
         config: dict = {
             "max_workers": 16,
             "max_input_tokens": 500000,
@@ -43,19 +42,11 @@ def __init__(
     ) -> None:
         self.sources = sources
         self.ontology = ontology
-        self.model_config = model_config
+        self.model = model.with_system_instruction(CREATE_ONTOLOGY_SYSTEM)
         self.config = config
 
     def _create_chat(self):
-        return GenerativeModel(
-            self.model_config.model,
-            generation_config=(
-                self.model_config.generation_config.to_generation_config()
-                if self.model_config.generation_config is not None
-                else None
-            ),
-            system_instruction=CREATE_ONTOLOGY_SYSTEM,
-        ).start_chat(response_validation=False)
+        return self.model.start_chat({"response_validation": False})
 
     def run(self, boundaries: str):
         tasks: list[Future[Ontology]] = []
@@ -90,7 +81,7 @@ def run(self, boundaries: str):
 
     def _process_source(
         self,
-        chat_session: ChatSession,
+        chat_session: GenerativeModelChatSession,
         document: Document,
         o: Ontology,
         boundaries: str,
@@ -107,15 +98,15 @@ def _process_source(
         logger.debug(f"Model response: {responses[response_idx].text}")
 
         while (
-            responses[response_idx].candidates[0].finish_reason
+            responses[response_idx].finish_reason
             == FinishReason.MAX_TOKENS
         ):
             response_idx += 1
             responses.append(self._call_model(chat_session, "continue"))
 
-        if responses[response_idx].candidates[0].finish_reason != FinishReason.STOP:
+        if responses[response_idx].finish_reason != FinishReason.STOP:
             raise Exception(
-                f"Model stopped unexpectedly: {responses[response_idx].candidates[0].finish_reason}"
+                f"Model stopped unexpectedly: {responses[response_idx].finish_reason}"
             )
 
         combined_text = " ".join([r.text for r in responses])
@@ -133,7 +124,7 @@ def _process_source(
 
         return o
 
-    def _fix_ontology(self, chat_session: ChatSession, o: Ontology):
+    def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):
         logger.debug(f"Fixing ontology...")
 
         user_message = FIX_ONTOLOGY_PROMPT.format(ontology=o)
@@ -146,15 +137,15 @@ def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):
         logger.debug(f"Model response: {responses[response_idx]}")
 
         while (
-            responses[response_idx].candidates[0].finish_reason
+            responses[response_idx].finish_reason
            == FinishReason.MAX_TOKENS
         ):
             response_idx += 1
             responses.append(self._call_model(chat_session, "continue"))
 
-        if responses[response_idx].candidates[0].finish_reason != FinishReason.STOP:
+        if responses[response_idx].finish_reason != FinishReason.STOP:
             raise Exception(
-                f"Model stopped unexpectedly: {responses[response_idx].candidates[0].finish_reason}"
+                f"Model stopped unexpectedly: {responses[response_idx].finish_reason}"
             )
 
         combined_text = " ".join([r.text for r in responses])
@@ -176,7 +167,7 @@ def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):
     @limits(calls=15, period=60)
     def _call_model(
         self,
-        chat_session: ChatSession,
+        chat_session: GenerativeModelChatSession,
         prompt: str,
         retry=6,
     ):
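Worth calling out in `CreateOntologyStep`: when the model stops with `MAX_TOKENS`, the step sends "continue" until it gets a clean `STOP`, then stitches the pieces together. The same loop appears in `ExtractDataStep` below. Distilled into a standalone sketch:

```python
from falkordb_gemini_kg.models import FinishReason, GenerativeModelChatSession


def ask_until_complete(chat_session: GenerativeModelChatSession, prompt: str) -> str:
    responses = [chat_session.send_message(prompt)]
    while responses[-1].finish_reason == FinishReason.MAX_TOKENS:
        # Output was truncated; ask the model to pick up where it left off.
        responses.append(chat_session.send_message("continue"))
    if responses[-1].finish_reason != FinishReason.STOP:
        raise Exception(f"Model stopped unexpectedly: {responses[-1].finish_reason}")
    return " ".join(r.text for r in responses)
```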
diff --git a/falkordb_gemini_kg/steps/extract_data_step.py b/falkordb_gemini_kg/steps/extract_data_step.py
index ed86f8e..935ff83 100644
--- a/falkordb_gemini_kg/steps/extract_data_step.py
+++ b/falkordb_gemini_kg/steps/extract_data_step.py
@@ -2,14 +2,14 @@
 from falkordb_gemini_kg.classes.source import AbstractSource
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from falkordb_gemini_kg.classes.ontology import Ontology
-from falkordb_gemini_kg.classes.model_config import StepModelConfig
-from vertexai.generative_models import (
+from falkordb_gemini_kg.models import (
     GenerativeModel,
-    ChatSession,
-    ResponseValidationError,
+    GenerativeModelChatSession,
+    GenerativeModelConfig,
     GenerationResponse,
     FinishReason,
 )
+
 from falkordb_gemini_kg.fixtures.prompts import (
     EXTRACT_DATA_SYSTEM,
     EXTRACT_DATA_PROMPT,
@@ -38,7 +38,7 @@ def __init__(
         self,
         sources: list[AbstractSource],
         ontology: Ontology,
-        model_config: StepModelConfig,
+        model: GenerativeModel,
         graph: Graph,
         config: dict = {
             "max_workers": 16,
@@ -49,24 +49,16 @@ def __init__(
         self.sources = sources
         self.ontology = ontology
         self.config = config
-        self.model_config = model_config
+        self.model = model.with_system_instruction(
+            EXTRACT_DATA_SYSTEM.replace("#ONTOLOGY", str(self.ontology.to_json()))
+        )
         self.graph = graph
 
         if not os.path.exists("logs"):
             os.makedirs("logs")
 
     def _create_chat(self):
-        return GenerativeModel(
-            self.model_config.model,
-            generation_config=(
-                self.model_config.generation_config.to_generation_config()
-                if self.model_config.generation_config is not None
-                else None
-            ),
-            system_instruction=EXTRACT_DATA_SYSTEM.replace(
-                "#ONTOLOGY", str(self.ontology.to_json())
-            ),
-        ).start_chat(response_validation=False)
+        return self.model.start_chat({"response_validation": False})
 
     def run(self):
 
@@ -100,7 +92,7 @@ def run(self):
     def _process_source(
         self,
         task_id: str,
-        chat_session: ChatSession,
+        chat_session: GenerativeModelChatSession,
         document: Document,
         ontology: Ontology,
         graph: Graph,
@@ -135,7 +127,7 @@ def _process_source(
         _task_logger.debug(f"Model response: {responses[response_idx]}")
 
         while (
-            responses[response_idx].candidates[0].finish_reason
+            responses[response_idx].finish_reason
             == FinishReason.MAX_TOKENS
         ):
             _task_logger.debug("Asking model to continue")
             response_idx += 1
             responses.append(self._call_model(chat_session, "continue"))
             _task_logger.debug(
                 f"Model response after continue: {responses[response_idx].text}"
             )
 
-        if responses[response_idx].candidates[0].finish_reason != FinishReason.STOP:
+        if responses[response_idx].finish_reason != FinishReason.STOP:
             _task_logger.debug(
-                f"Model stopped unexpectedly: {responses[response_idx].candidates[0].finish_reason}"
+                f"Model stopped unexpectedly: {responses[response_idx].finish_reason}"
             )
             raise Exception(
-                f"Model stopped unexpectedly: {responses[response_idx].candidates[0].finish_reason}"
+                f"Model stopped unexpectedly: {responses[response_idx].finish_reason}"
             )
 
         combined_text = " ".join([r.text for r in responses])
@@ -268,7 +260,7 @@ def _create_edge(self, graph: Graph, args: dict, ontology: Ontology):
     @limits(calls=15, period=60)
     def _call_model(
         self,
-        chat_session: ChatSession,
+        chat_session: GenerativeModelChatSession,
         prompt: str,
         retry=6,
     ):
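`_call_model` in both steps is throttled with the `ratelimit` package (added to `pyproject.toml` below) at 15 calls per 60 seconds; `sleep_and_retry` is imported alongside it, presumably so callers block rather than fail when the quota is hit. The pattern in isolation:

```python
from ratelimit import limits, sleep_and_retry


@sleep_and_retry              # block until a slot frees up instead of raising
@limits(calls=15, period=60)  # at most 15 calls per 60 seconds
def call_model(prompt: str) -> str:
    ...  # send `prompt` to the model here
    return "response"
```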
diff --git a/falkordb_gemini_kg/steps/graph_query_step.py b/falkordb_gemini_kg/steps/graph_query_step.py
index c6e7e06..f0a9b64 100644
--- a/falkordb_gemini_kg/steps/graph_query_step.py
+++ b/falkordb_gemini_kg/steps/graph_query_step.py
@@ -1,7 +1,8 @@
 from falkordb_gemini_kg.steps.Step import Step
 from falkordb_gemini_kg.classes.ontology import Ontology
-from falkordb_gemini_kg.classes.model_config import StepModelConfig
-from vertexai.generative_models import GenerativeModel
+from falkordb_gemini_kg.models import (
+    GenerativeModelChatSession,
+)
 from falkordb_gemini_kg.fixtures.prompts import (
     CYPHER_GEN_SYSTEM,
     CYPHER_GEN_PROMPT,
@@ -28,30 +29,13 @@ def __init__(
         self,
         graph: Graph,
         ontology: Ontology,
-        model_config: StepModelConfig | None = None,
+        chat_session: GenerativeModelChatSession,
         config: dict = {},
-        chat_session: GenerativeModel | None = None,
     ) -> None:
-        assert chat_session is not None or (
-            model_config is not None
-        ), "Must provide either a chat session or model config"
         self.ontology = ontology
         self.config = config
         self.graph = graph
-        self.chat_session = (
-            chat_session
-            or GenerativeModel(
-                model_config.model,
-                generation_config=(
-                    model_config.generation_config.to_generation_config()
-                    if model_config.generation_config is not None
-                    else None
-                ),
-                system_instruction=CYPHER_GEN_SYSTEM.replace(
-                    "#ONTOLOGY", str(ontology.to_json())
-                ),
-            ).start_chat()
-        )
+        self.chat_session = chat_session
 
     def run(self, question: str, retries: int = 5):
         error = False
@@ -66,6 +50,7 @@ def run(self, question: str, retries: int = 5):
                     question=question, error=error
                 )
             )
+            logger.debug(f"Cypher Prompt: {cypher_prompt}")
             cypher_statement_response = self.chat_session.send_message(
                 cypher_prompt,
             )
@@ -79,6 +64,7 @@ def run(self, question: str, retries: int = 5):
                 if cypher is not None:
                     result_set = self.graph.query(cypher).result_set
                     context = stringify_falkordb_response(result_set)
+                    logger.debug(f"Context: {context}")
                     logger.debug(f"Context size: {len(result_set)}")
                     logger.debug(f"Context characters: {len(str(context))}")
 
@@ -87,3 +73,5 @@ def run(self, question: str, retries: int = 5):
                 logger.debug(f"Error: {e}")
                 error = e
                 retries -= 1
+
+        raise Exception("Failed to generate Cypher query: " + str(error))
diff --git a/falkordb_gemini_kg/steps/qa_step.py b/falkordb_gemini_kg/steps/qa_step.py
index 79653f5..4124bc3 100644
--- a/falkordb_gemini_kg/steps/qa_step.py
+++ b/falkordb_gemini_kg/steps/qa_step.py
@@ -1,8 +1,6 @@
 from falkordb_gemini_kg.steps.Step import Step
-from falkordb_gemini_kg.classes.Document import Document
-from falkordb_gemini_kg.classes.ontology import Ontology
-from falkordb_gemini_kg.classes.model_config import StepModelConfig
-from vertexai.generative_models import GenerativeModel, ChatSession
+from falkordb_gemini_kg.models import GenerativeModelChatSession
+
 from falkordb_gemini_kg.fixtures.prompts import GRAPH_QA_SYSTEM, GRAPH_QA_PROMPT
 import logging
 
@@ -17,26 +15,11 @@ class QAStep(Step):
 
     def __init__(
         self,
-        model_config: StepModelConfig | None = None,
+        chat_session: GenerativeModelChatSession,
         config: dict = {},
-        chat_session: ChatSession | None = None,
     ) -> None:
-        assert chat_session is not None or (
-            model_config is not None
-        ), "Must provide either a chat session or model config"
         self.config = config
-        self.chat_session = (
-            chat_session
-            or GenerativeModel(
-                model_config.model,
-                generation_config=(
-                    model_config.generation_config.to_generation_config()
-                    if model_config.generation_config is not None
-                    else None
-                ),
-                system_instruction=GRAPH_QA_SYSTEM,
-            ).start_chat()
-        )
+        self.chat_session = chat_session
 
     def run(self, question: str, cypher: str, context: str):
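`GraphQueryGenerationStep.run` now fails loudly: each retry feeds the previous exception back into the prompt, and exhausting the retries raises instead of silently returning nothing. The control flow, reduced to a sketch (the prompt format here is a stand-in for `CYPHER_GEN_PROMPT`):

```python
def generate_cypher(chat_session, question: str, retries: int = 5):
    error = False
    while retries > 0:
        try:
            prompt = f"Question: {question}" + (f"\nPrevious error: {error}" if error else "")
            response = chat_session.send_message(prompt)
            return response.text
        except Exception as e:
            error = e      # surfaces in the next prompt
            retries -= 1
    raise Exception("Failed to generate Cypher query: " + str(error))
```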
diff --git a/poetry.lock b/poetry.lock
index ba28e99..087a5ec 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1205,6 +1205,16 @@ pluggy = ">=1.5,<2.0"
 [package.extras]
 dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "python-abc"
+version = "0.2.0"
+description = "A python implementation of the ABC Software metric"
+optional = false
+python-versions = "*"
+files = [
+    {file = "python-abc-0.2.0.tar.gz", hash = "sha256:90017d09fbac7bde4b64b2c7e1b5d22da9055b64b821d1a2b4dc805b450b251a"},
+]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -1219,6 +1229,20 @@ files = [
 [package.dependencies]
 six = ">=1.5"
 
+[[package]]
+name = "python-dotenv"
+version = "1.0.1"
+description = "Read key-value pairs from a .env file and set them as environment variables"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
+    {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
+]
+
+[package.extras]
+cli = ["click (>=5.0)"]
+
 [[package]]
 name = "pytz"
 version = "2024.1"
@@ -1230,6 +1254,16 @@ files = [
     {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
 ]
 
+[[package]]
+name = "ratelimit"
+version = "2.2.1"
+description = "API rate limit decorator"
+optional = false
+python-versions = "*"
+files = [
+    {file = "ratelimit-2.2.1.tar.gz", hash = "sha256:af8a9b64b821529aca09ebaf6d8d279100d766f19e90b5059ac6a718ca6dee42"},
+]
+
 [[package]]
 name = "redis"
 version = "5.0.6"
@@ -1697,4 +1731,4 @@ xai = ["tensorflow (>=2.3.0,<3.0.0dev)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11.4"
-content-hash = "942341bb3e408c5eacf030255f70888c1f26c7f7e1f644b8dbefbd60cf5e2164"
+content-hash = "ce0d4be4a2614e6cda568a2af1f0068532528e20f01163963cccb7f0f63851b4"
diff --git a/pyproject.toml b/pyproject.toml
index a3d00b0..025975c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,9 @@ sqlalchemy = "^2.0.30"
 pypdf = "^4.2.0"
 vertexai = "^1.49.0"
 backoff = "^2.2.1"
+python-abc = "^0.2.0"
+ratelimit = "^2.2.1"
+python-dotenv = "^1.0.1"
 
 [tool.poetry.group.test.dependencies]
 pytest = "^8.2.1"
@@ -33,6 +36,6 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.pytest.ini_options]
 log_cli = true
-log_cli_level = "DEBUG"
+log_cli_level = "INFO"
 log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
 log_cli_date_format = "%Y-%m-%d %H:%M:%S"
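The new `python-dotenv` dependency is what lets the tests below run locally: `load_dotenv()` is called before `vertexai.init`, so `PROJECT_ID` and `REGION` can come from a `.env` file instead of CI secrets. A sketch of the local setup (values are placeholders):

```python
# .env (not committed; values are placeholders):
#   PROJECT_ID=my-gcp-project
#   REGION=us-central1

import os

from dotenv import load_dotenv

load_dotenv()  # reads .env into the process environment
print(os.getenv("PROJECT_ID"), os.getenv("REGION"))
```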
diff --git a/tests/data/madoff.txt b/tests/data/madoff.txt
new file mode 100644
index 0000000..b419c76
--- /dev/null
+++ b/tests/data/madoff.txt
@@ -0,0 +1,249 @@
+
+
+
+Episode guide
+4
+Cast & crew
+User reviews
+FAQ
+IMDbPro
+
+Madoff: The Monster of Wall Street
+TV Mini Series
+2023
+15
+1h 2m
+IMDb RATING
+YOUR RATING
+Madoff: The Monster of Wall Street (2023)
+
+It follows the rise and fall of the American financier and ponzi schemer: Madoff.
+Play trailer with sound1:50
+It follows the rise and fall of the American financier and ponzi schemer: Madoff.
+
+Stars
+Joseph ScottoMelony FelicianoDonna Pastorello
+See production info at IMDbPro
+31
+User reviews
+16
+Critic reviews
+Episodes
+4
+TOP-RATED
+Wed, Jan 4, 2023
+S1.E4
+The Price of Trust
+Madoff's scheme is exposed amid the 2008 financial market crisis. His victims' lives are upended as they trace years of obstacles to recoup their losses.
+7.7
+/10
+TOP-RATED
+Wed, Jan 4, 2023
+S1.E3
+See No Evil
+Competitors investigate Madoff's impossible numbers and alert the Securities and Exchange Commission, but the agency shrugs off multiple red flags.
+7.5
+/10
+BROWSE EPISODES
+Videos
+2
+Official Trailer
+Trailer 1:50
+Watch Official Trailer
+Madoff: The Monster Of Wall Street
+Trailer 1:42
+Watch Madoff: The Monster Of Wall Street
+Photos
+12
+Madoff: The Monster of Wall Street (2023)
+Ginger OToole and Joseph Scotto in Netflix docuseries Madoff: The Monster of Wall Street
+Netflix docuseries Madoff: The Monster of Wall Street - Young Bernie Madoff
+Netflix docuseries Madoff: The Monster of Wall Street - Young Bernie Madoff
+Donna Pastorello and Eleanor Squillari in Madoff: The Monster of Wall Street (2023)
+Donna Pastorello in Madoff: The Monster of Wall Street (2023)
+Joseph Scotto, Donna Pastorello, Sarah Kuklis, Isa Camyar, Alex Olson, Kevin Delano, Cris Colicchio, and Ashley Rose Folino in Madoff: The Monster of Wall Street (2023)
+Joseph Scotto and Cris Colicchio in Madoff: The Monster of Wall Street (2023)
+Madoff: The Monster of Wall Street (2023)
+Madoff: The Monster of Wall Street (2023)
+Madoff: The Monster of Wall Street (2023)
+Top cast
+Joseph Scotto
+Joseph Scotto
+Bernie Madoff
+Melony Feliciano
+Melony Feliciano
+Background Extra
+Donna Pastorello
+Donna Pastorello
+Eleanor Squillari
+Isa Camyar
+Isa Camyar
+Frank DiPascali
+Sarah Kuklis
+Sarah Kuklis
+Ellen Hales
+Alex Olson
+Alex Olson
+Mark Madoff
+Elijah George
+Elijah George
+19th Floor Trader
+Howie Schaal
+Howie Schaal
+Jerry O'Hara
+Stephanie Beauchamp
+Jodi Crupi
+Cris Colicchio
+Cris Colicchio
+Peter Madoff
+Alex Hammerli
+Madoff Employee
+Alicia Erlinger
+Annette Bongiorno
+Diana B. Henriques
+Self - Author - The Wizard of Lies…
+Kevin Delano
+Andrew Madoff
+Robert Loftus
+19th Floor Trader
+Paul Faggione
+Paul Faggione
+Jeffrey Tucker
+Marla Freeman
+Marla Freeman
+Sonja Kohn
+Rafael Antonio Vasquez
+Rafael Antonio Vasquez
+George Perez
+All cast & crew
+Production, box office & more at IMDbPro
+More like this
+Murdaugh Murders: A Southern Scandal
+6.8
+Murdaugh Murders: A Southern Scandal
+
+Waco: American Apocalypse
+7.0
+Waco: American Apocalypse
+
+Pepsi, Where's My Jet?
+7.0
+Pepsi, Where's My Jet?
+
+The Hatchet Wielding Hitchhiker
+6.2
+The Hatchet Wielding Hitchhiker
+
+American Manhunt: The Boston Marathon Bombing
+7.5
+American Manhunt: The Boston Marathon Bombing
+
+MH370: The Plane That Disappeared
+6.1
+MH370: The Plane That Disappeared
+
+Eat the Rich: The GameStop Saga
+6.3
+Eat the Rich: The GameStop Saga
+
+Madoff
+7.4
+Madoff
+
+FIFA Uncovered
+7.4
+FIFA Uncovered
+
+Killer Sally
+6.7
+Killer Sally
+
+Jeffrey Epstein: Filthy Rich
+7.1
+Jeffrey Epstein: Filthy Rich
+
+Pamela: A Love Story
+7.2
+Pamela: A Love Story
+
+Storyline
+Did you know
+Connections
+Featured in Jeremy Vine: Episode #6.5 (2023)
+User reviews
+31
+FEATURED REVIEW
+6
+/10
+Good to know but way too slow.
+A very interesting documentary on a topic that I think is good to know for most people. But, in typical Netflix fashion, is dragged out over way too many episodes and cost you way too much of your time.
+
+The lack of original footage has been made up by some acting, but all scenes seem endlessly repeated. Good they interviewed a lot of involved people and they got a hand on some original content though. But it feels like it could fit in a 1 hour documentary movie. So yes this is a great topic, but I'd rather advice to read a news article on it than to spend your hours of precious time on Netflix.
+
+helpful
+•
+16
+
+3
+
+owen89Jan 10, 2023
+Top picks
+Sign in to rate and Watchlist for personalized recommendations
+FAQ
+14
+How many seasons does Madoff: The Monster of Wall Street have?
+Powered by Alexa
+Details
+Release date
+January 4, 2023 (United Kingdom)
+Country of origin
+United States
+Official site
+Netflix Site
+Language
+English
+Also known as
+МЕЙДОФФ: Монстр із Волл-стріт
+Production companies
+RadicalMediaThird Eye Motion Picture Company
+See more company credits at IMDbPro
+Technical specs
+Runtime
+1 hour 2 minutes
+Color
+Color
+Sound mix
+Dolby Digital
+Aspect ratio
+16:9 HD
+Related news
+Contribute to this page
+Suggest an edit or add missing content
+IMDb Answers: Help fill gaps in our data
+Learn more about contributing
+More to explore
+Production art
+Photos
+Hollywood Power Couples
+See the gallery
+Production art
+Photos
+The Greatest Character Actors of All Time
+See the gallery
+Poster
+List
+5 Movies to Watch While Gearing Up for 'Furiosa'
+See our picks
+Recently viewed
+Madoff: The Monster of Wall Street (2023)
+Madoff: The Monster of Wall Street
+WeWork (2021)
+WeWork
+Follow IMDb on social
+Get the IMDb app
+For Android and iOS
+Get the IMDb app
+HelpSite IndexIMDbProBox Office MojoLicense IMDb Data
+Press RoomAdvertisingJobsConditions of UsePrivacy Policy
+Your Ads Privacy Choices
+© 1990-2024 by IMDb.com, Inc.
diff --git a/tests/test_auto_create_ontology.py b/tests/test_auto_create_ontology.py
new file mode 100644
index 0000000..ead6e17
--- /dev/null
+++ b/tests/test_auto_create_ontology.py
@@ -0,0 +1,34 @@
+from dotenv import load_dotenv
+load_dotenv()
+from falkordb_gemini_kg.classes.ontology import Ontology
+import unittest
+from falkordb_gemini_kg.classes.source import Source
+from falkordb_gemini_kg.models.gemini import GeminiGenerativeModel
+import vertexai
+import os
+import logging
+
+logging.basicConfig(level=logging.DEBUG)
+
+vertexai.init(project=os.getenv("PROJECT_ID"), location=os.getenv("REGION"))
+
+
+class TestAutoDetectOntology(unittest.TestCase):
+    """
+    Test auto-detect ontology
+    """
+
+    def test_auto_detect_ontology(self):
+
+        file_path = "tests/data/madoff.txt"
+
+        sources = [Source(file_path)]
+
+        model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
+
+        boundaries = """
+    Extract entities and relationships from each page
+"""
+        ontology = Ontology.from_sources(sources, boundaries=boundaries, model=model)
+
+        logging.info(f"Ontology: {ontology.to_json()}")
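The test only logs the detected ontology; a real pipeline would typically persist it for the extraction step, much like the `output/ontology.json` path in the deleted `config.yaml`. A hedged sketch (it assumes `to_json()` returns either a JSON string or a serializable structure; this diff does not pin that down):

```python
import json

# Persist the detected ontology for later runs (path is illustrative;
# `ontology` is the object returned by Ontology.from_sources above).
data = ontology.to_json()
with open("output/ontology.json", "w") as f:
    f.write(data if isinstance(data, str) else json.dumps(data, indent=2))
```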
diff --git a/tests/test_kg.py b/tests/test_kg.py
new file mode 100644
index 0000000..4b01894
--- /dev/null
+++ b/tests/test_kg.py
@@ -0,0 +1,102 @@
+from dotenv import load_dotenv
+
+load_dotenv()
+from falkordb_gemini_kg.classes.ontology import Ontology
+from falkordb_gemini_kg.classes.node import Node
+from falkordb_gemini_kg.classes.edge import Edge
+from falkordb_gemini_kg.classes.attribute import Attribute, AttributeType
+import unittest
+from falkordb_gemini_kg.classes.source import Source
+from falkordb_gemini_kg.models.gemini import GeminiGenerativeModel
+from falkordb_gemini_kg import KnowledgeGraph, KnowledgeGraphModelConfig
+import vertexai
+import os
+import logging
+from falkordb import FalkorDB
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+vertexai.init(project=os.getenv("PROJECT_ID"), location=os.getenv("REGION"))
+
+
+class TestKG(unittest.TestCase):
+    """
+    Test Knowledge Graph
+    """
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.ontology = Ontology([], [])
+
+        cls.ontology.add_node(
+            Node(
+                label="Actor",
+                attributes=[
+                    Attribute(
+                        name="name",
+                        attr_type=AttributeType.STRING,
+                        unique=True,
+                        required=True,
+                    ),
+                ],
+            )
+        )
+        cls.ontology.add_node(
+            Node(
+                label="Movie",
+                attributes=[
+                    Attribute(
+                        name="title",
+                        attr_type=AttributeType.STRING,
+                        unique=True,
+                        required=True,
+                    ),
+                ],
+            )
+        )
+        cls.ontology.add_edge(
+            Edge(
+                label="ACTED_IN",
+                source="Actor",
+                target="Movie",
+                attributes=[
+                    Attribute(
+                        name="role",
+                        attr_type=AttributeType.STRING,
+                        unique=False,
+                        required=False,
+                    ),
+                ],
+            )
+        )
+
+        model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
+        cls.kg = KnowledgeGraph(
+            name="IMDB",
+            ontology=cls.ontology,
+            model_config=KnowledgeGraphModelConfig.with_model(model),
+        )
+
+    def test_kg_creation(self):
+
+        file_path = "tests/data/madoff.txt"
+
+        sources = [Source(file_path)]
+
+        self.kg.process_sources(sources)
+
+        answer = self.kg.ask("List a few actors")
+
+        logger.info(f"Answer: {answer}")
+
+        assert "Joseph Scotto" in answer, "Joseph Scotto not found in answer"
+
+    def test_kg_delete(self):
+
+        self.kg.delete()
+
+        db = FalkorDB()
+        graphs = db.list_graphs()
+        self.assertNotIn("IMDB", graphs)
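Beyond `kg.ask`, the populated graph can be sanity-checked directly with the FalkorDB client (a sketch; `select_graph` is the client's standard accessor, and the Cypher assumes the `Actor`/`Movie` ontology defined in the test):

```python
from falkordb import FalkorDB

db = FalkorDB()  # defaults to localhost:6379, like the CI service container
graph = db.select_graph("IMDB")
result = graph.query(
    "MATCH (a:Actor)-[:ACTED_IN]->(m:Movie) RETURN a.name, m.title LIMIT 5"
)
for name, title in result.result_set:
    print(name, title)
```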