diff --git a/falkordb_gemini_kg/classes/ChatSession.py b/falkordb_gemini_kg/classes/ChatSession.py
index 84446b9..b681660 100644
--- a/falkordb_gemini_kg/classes/ChatSession.py
+++ b/falkordb_gemini_kg/classes/ChatSession.py
@@ -46,6 +46,9 @@ def send_message(self, message: str):
 
         (context, cypher) = cypher_step.run(message)
 
+        if not cypher or len(cypher) == 0:
+            return "I am sorry, I could not find the answer to your question"
+
         qa_step = QAStep(
             chat_session=self.qa_chat_session,
         )
diff --git a/falkordb_gemini_kg/classes/attribute.py b/falkordb_gemini_kg/classes/attribute.py
index 5b75059..8fcee78 100644
--- a/falkordb_gemini_kg/classes/attribute.py
+++ b/falkordb_gemini_kg/classes/attribute.py
@@ -1,6 +1,7 @@
 import json
 from falkordb_gemini_kg.fixtures.regex import *
 import logging
+import re
 
 logger = logging.getLogger(__name__)
 
@@ -23,7 +24,7 @@ class Attribute:
     def __init__(
         self, name: str, attr_type: AttributeType, unique: bool, required: bool = False
     ):
-        self.name = name
+        self.name = re.sub(r"([^a-zA-Z0-9_])", "_", name)
         self.type = attr_type
         self.unique = unique
         self.required = required
@@ -70,4 +71,3 @@ def to_json(self):
 
     def __str__(self) -> str:
         return f"{self.name}: \"{self.type}{'!' if self.unique else ''}{'*' if self.required else ''}\""
-
diff --git a/falkordb_gemini_kg/classes/edge.py b/falkordb_gemini_kg/classes/edge.py
index d92f2e1..3c69252 100644
--- a/falkordb_gemini_kg/classes/edge.py
+++ b/falkordb_gemini_kg/classes/edge.py
@@ -7,14 +7,15 @@
 
 logger = logging.getLogger(__name__)
 
+
 class _EdgeNode:
     def __init__(self, label: str):
-        self.label = label
+        self.label = re.sub(r"([^a-zA-Z0-9_])", "", label)
 
     @staticmethod
     def from_json(txt: str):
         txt = txt if isinstance(txt, dict) else json.loads(txt)
-        return _EdgeNode(txt["label"])
+        return _EdgeNode(txt["label"] if "label" in txt else txt)
 
     def to_json(self):
         return {"label": self.label}
@@ -28,22 +29,21 @@ def __init__(
         self,
         label: str,
         source: _EdgeNode | str,
-        target: _EdgeNode | str, 
+        target: _EdgeNode | str,
         attributes: list[Attribute],
     ):
-
+
         if isinstance(source, str):
             source = _EdgeNode(source)
         if isinstance(target, str):
             target = _EdgeNode(target)
-
+
         assert isinstance(label, str), "Label must be a string"
         assert isinstance(source, _EdgeNode), "Source must be an EdgeNode"
         assert isinstance(target, _EdgeNode), "Target must be an EdgeNode"
         assert isinstance(attributes, list), "Attributes must be a list"
-
-        self.label = label
+        self.label = re.sub(r"([^a-zA-Z0-9_])", "", label.upper())
         self.source = source
         self.target = target
         self.attributes = attributes
diff --git a/falkordb_gemini_kg/classes/node.py b/falkordb_gemini_kg/classes/node.py
index 46d383e..ae4b278 100644
--- a/falkordb_gemini_kg/classes/node.py
+++ b/falkordb_gemini_kg/classes/node.py
@@ -2,13 +2,16 @@
 import logging
 from .attribute import Attribute, AttributeType
 from falkordb import Node as GraphNode
-
+import re
 logger = logging.getLogger(__name__)
 
+descriptionKey = "__description__"
+
 class Node:
-    def __init__(self, label: str, attributes: list[Attribute]):
-        self.label = label
+    def __init__(self, label: str, attributes: list[Attribute], description: str = ""):
+        self.label = re.sub(r"([^a-zA-Z0-9_])", "", label)
         self.attributes = attributes
+        self.description = description
 
     @staticmethod
     def from_graph(node: GraphNode):
@@ -21,22 +24,28 @@
                     AttributeType.fromString(node.properties[attr]),
                     "!" in node.properties[attr],
                 )
-                for attr in node.properties
+                for attr in node.properties if attr != descriptionKey
             ],
+            node.properties[descriptionKey] if descriptionKey in node.properties else "",
         )
 
     @staticmethod
     def from_json(txt: dict | str):
         txt = txt if isinstance(txt, dict) else json.loads(txt)
         return Node(
-            txt["label"].replace(" ", ""),
-            [Attribute.from_json(attr) for attr in txt["attributes"]],
+            txt["label"],
+            [
+                Attribute.from_json(attr)
+                for attr in (txt["attributes"] if "attributes" in txt else [])
+            ],
+            txt["description"] if "description" in txt else "",
        )
 
     def to_json(self):
         return {
             "label": self.label,
             "attributes": [attr.to_json() for attr in self.attributes],
+            "description": self.description,
         }
 
     def combine(self, node2: "Node"):
@@ -55,10 +64,12 @@ def get_unique_attributes(self):
         return [attr for attr in self.attributes if attr.unique]
 
     def to_graph_query(self):
-        return f"MERGE (n:{self.label} {{{', '.join([str(attr) for attr in self.attributes])}}}) RETURN n"
+        attributes = ", ".join([str(attr) for attr in self.attributes])
+        if self.description:
+            attributes += f"{', ' if len(attributes) > 0 else ''} {descriptionKey}: '{self.description}'"
+        return f"MERGE (n:{self.label} {{{attributes}}}) RETURN n"
 
     def __str__(self) -> str:
         return (
             f"(:{self.label} {{{', '.join([str(attr) for attr in self.attributes])}}})"
         )
-
diff --git a/falkordb_gemini_kg/classes/ontology.py b/falkordb_gemini_kg/classes/ontology.py
index 1935e8b..e37980f 100644
--- a/falkordb_gemini_kg/classes/ontology.py
+++ b/falkordb_gemini_kg/classes/ontology.py
@@ -146,8 +146,8 @@ def validate_nodes(self):
     def get_node_with_label(self, label: str):
         return next((n for n in self.nodes if n.label == label), None)
 
-    def get_edge_with_label(self, label: str):
-        return next((e for e in self.edges if e.label == label), None)
+    def get_edges_with_label(self, label: str):
+        return [e for e in self.edges if e.label == label]
 
     def has_node_with_label(self, label: str):
         return any(n.label == label for n in self.nodes)
diff --git a/falkordb_gemini_kg/document_loaders/url.py b/falkordb_gemini_kg/document_loaders/url.py
index 97955db..cce4c33 100644
--- a/falkordb_gemini_kg/document_loaders/url.py
+++ b/falkordb_gemini_kg/document_loaders/url.py
@@ -21,7 +21,7 @@ def __init__(self, url: str) -> None:
 
     def _download(self) -> str:
         try:
-            response = requests.get(self.url)
+            response = requests.get(self.url, headers={'User-Agent': 'Mozilla/5.0'})
             response.raise_for_status()  # Raise an HTTPError for bad responses (4xx and 5xx)
             return response.text
         except requests.exceptions.RequestException as e:
diff --git a/falkordb_gemini_kg/fixtures/prompts.py b/falkordb_gemini_kg/fixtures/prompts.py
index fbef3c3..9e18219 100644
--- a/falkordb_gemini_kg/fixtures/prompts.py
+++ b/falkordb_gemini_kg/fixtures/prompts.py
@@ -9,6 +9,7 @@
 Add as many attributes to nodes and edges as necessary to fully describe the entities and relationships in the text.
 Prefer to convert edges into nodes when they have attributes. For example, if an edge represents a relationship with attributes, convert it into a node with the attributes as properties.
 Create a very concise and clear ontology. Avoid unnecessary complexity and ambiguity in the ontology.
+Node and edge labels cannot start with numbers or special characters.
 
 ## 2. Labeling Nodes
 - **Consistency**: Ensure you use available types for node labels. Ensure you use basic or elementary types for node labels. For example, when you identify an entity representing a person, always label it as **'person'**. Avoid using more specific terms "like 'mathematician' or 'scientist'"
@@ -338,6 +339,9 @@
 All formats should be consistent, for example, dates should be in the format "YYYY-MM-DD".
 If needed, add the correct spacing for text fields, where the text is not properly formatted.
 
+User instructions:
+{instructions}
+
 Raw Text:
 {text}
 """
@@ -410,7 +414,7 @@
 Try to generate a new valid OpenCypher statement.
 Use only the provided nodes, relationships types and properties in the ontology.
 The output must be only a valid OpenCypher statement.
-Do not include any text except the generated OpenCypher statement, enclosed in triple backticks.
+Do not include any apologies or other texts, except the generated OpenCypher statement, enclosed in triple backticks.
 
 Question: {question}
 """
diff --git a/falkordb_gemini_kg/helpers.py b/falkordb_gemini_kg/helpers.py
index 8a172ae..bd25675 100644
--- a/falkordb_gemini_kg/helpers.py
+++ b/falkordb_gemini_kg/helpers.py
@@ -1,14 +1,16 @@
 import re
-from falkordb_gemini_kg.classes.ontology import Ontology
+import falkordb_gemini_kg
 import logging
+from fix_busted_json import repair_json
 
 logger = logging.getLogger(__name__)
 
+
 def extract_json(text: str):
     regex = r"(?:```)?(?:json)?([^`]*)(?:\\n)?(?:```)?"
     matches = re.findall(regex, text, re.DOTALL)
 
-    return "".join(matches)
+    return repair_json("".join(matches))
 
 
 def map_dict_to_cypher_properties(d: dict):
@@ -69,10 +71,12 @@ def extract_cypher(text: str):
 
     return "".join(matches)
 
-def validate_cypher(cypher: str, ontology: Ontology) -> list[str] | None:
+def validate_cypher(
+    cypher: str, ontology: falkordb_gemini_kg.Ontology
+) -> list[str] | None:
     try:
         if not cypher or len(cypher) == 0:
-            return "Cypher statement is empty"
+            return ["Cypher statement is empty"]
 
         errors = []
 
@@ -94,7 +98,7 @@ def validate_cypher(cypher: str, ontology: Ontology) -> list[str] | None:
 
     return None
 
-def validate_cypher_nodes_exist(cypher: str, ontology: Ontology):
+def validate_cypher_nodes_exist(cypher: str, ontology: falkordb_gemini_kg.Ontology):
     # Check if nodes exist in ontology
     not_found_node_labels = []
     node_labels = re.findall(r"\(:(.*?)\)", cypher)
@@ -107,7 +111,7 @@ def validate_cypher_nodes_exist(cypher: str, ontology: Ontology):
 
     return [f"Node {label} not found in ontology" for label in not_found_node_labels]
 
-def validate_cypher_edges_exist(cypher: str, ontology: Ontology):
+def validate_cypher_edges_exist(cypher: str, ontology: falkordb_gemini_kg.Ontology):
     # Check if edges exist in ontology
     not_found_edge_labels = []
     edge_labels = re.findall(r"\[:(.*?)\]", cypher)
@@ -120,7 +124,7 @@ def validate_cypher_edges_exist(cypher: str, ontology: Ontology):
 
     return [f"Edge {label} not found in ontology" for label in not_found_edge_labels]
 
-def validate_cypher_edge_directions(cypher: str, ontology: Ontology):
+def validate_cypher_edge_directions(cypher: str, ontology: falkordb_gemini_kg.Ontology):
     errors = []
     edges = list(re.finditer(r"\[.*?\]", cypher))
 
@@ -160,22 +164,37 @@ def validate_cypher_edge_directions(cypher: str, ontology: Ontology):
             source_label = re.search(r"(?:\:)([^\)\{]+)", source).group(1).strip()
             target_label = re.search(r"(?:\:)([^\)\{]+)", target).group(1).strip()
 
-            ontology_edge = ontology.get_edge_with_label(edge_label)
+            ontology_edges = ontology.get_edges_with_label(edge_label)
 
-            if ontology_edge is None:
+            if len(ontology_edges) == 0:
                 errors.append(f"Edge {edge_label} not found in ontology")
 
-            if (
-                not ontology_edge.source.label == source_label
-                or not ontology_edge.target.label == target_label
-            ):
+            found_edge = False
+            for ontology_edge in ontology_edges:
+                if (
+                    ontology_edge.source.label == source_label
+                    and ontology_edge.target.label == target_label
+                ):
+                    found_edge = True
+                    break
+
+            if not found_edge:
                 errors.append(
-                    f"Edge {edge_label} has a mismatched source or target. Make sure the edge direction is correct. The edge should connect {ontology_edge.source.label} to {ontology_edge.target.label}."
+                    """
+    Edge {edge_label} does not connect {source_label} to {target_label}. Make sure the edge direction is correct.
+    Valid edges:
+    {valid_edges}
+""".format(
+                        edge_label=edge_label,
+                        source_label=source_label,
+                        target_label=target_label,
+                        valid_edges="\n".join([str(e) for e in ontology_edges]),
+                    )
                 )
 
             i += 1
         except Exception as e:
-            errors.append(str(e))
+            # errors.append(str(e))
             continue
 
     return errors
diff --git a/falkordb_gemini_kg/kg.py b/falkordb_gemini_kg/kg.py
index 5209bd1..1e94744 100644
--- a/falkordb_gemini_kg/kg.py
+++ b/falkordb_gemini_kg/kg.py
@@ -82,7 +82,7 @@ def list_sources(self) -> list[AbstractSource]:
 
         return [s.source for s in self.sources]
 
-    def process_sources(self, sources: list[AbstractSource]) -> None:
+    def process_sources(self, sources: list[AbstractSource], instructions: str = None) -> None:
         """
         Add entities and relations found in sources into the knowledge-graph
 
@@ -94,13 +94,13 @@ def process_sources(self, sources: list[AbstractSource]) -> None:
             raise Exception("Ontology is not defined")
 
         # Create graph with sources
-        self._create_graph_with_sources(sources)
+        self._create_graph_with_sources(sources, instructions)
 
         # Add processed sources
         for src in sources:
             self.sources.add(src)
 
-    def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None):
+    def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None, instructions: str = None):
 
         step = ExtractDataStep(
             sources=list(sources),
@@ -109,7 +109,7 @@ def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None
             graph=self.graph,
         )
 
-        step.run()
+        step.run(instructions)
 
     def ask(self, question: str) -> str:
         """
@@ -138,6 +138,9 @@ def ask(self, question: str) -> str:
 
         (context, cypher) = cypher_step.run(question)
 
+        if not cypher or len(cypher) == 0:
+            return "I am sorry, I could not find the answer to your question"
+
         qa_chat_session = self._model_config.qa.with_system_instruction(
             GRAPH_QA_SYSTEM
         ).start_chat()
diff --git a/falkordb_gemini_kg/models/model.py b/falkordb_gemini_kg/models/model.py
index 847a4da..fabe089 100644
--- a/falkordb_gemini_kg/models/model.py
+++ b/falkordb_gemini_kg/models/model.py
@@ -29,6 +29,9 @@ def __init__(self, text: str, finish_reason: FinishReason):
         self.text = text
         self.finish_reason = finish_reason
 
+    def __str__(self) -> str:
+        return f"GenerationResponse(text={self.text}, finish_reason={self.finish_reason})"
+
 
 class GenerativeModelChatSession(ABC):
 
diff --git a/falkordb_gemini_kg/models/openai.py b/falkordb_gemini_kg/models/openai.py
index 8b4650c..5317ebb 100644
--- a/falkordb_gemini_kg/models/openai.py
+++ b/falkordb_gemini_kg/models/openai.py
@@ -37,7 +37,7 @@ def ask(self, message: str) -> GenerationResponse:
             model=self.model_name,
             messages=[
                 {"role": "system", "content": self.system_instruction},
-                {"role": "user", "content": message},
+                {"role": "user", "content": message[:14385]},
             ],
             max_tokens=self.generation_config.max_output_tokens,
             temperature=self.generation_config.temperature,
@@ -78,7 +78,7 @@ def __init__(self, model: OpenAiGenerativeModel, args: dict | None = None):
     def send_message(self, message: str) -> GenerationResponse:
         prompt = []
         prompt.extend(self._history)
-        prompt.append({"role": "user", "content": message})
+        prompt.append({"role": "user", "content": message[:14385]})
         response = self._model.client.chat.completions.create(
             model=self._model.model_name,
             messages=prompt,
diff --git a/falkordb_gemini_kg/steps/create_ontology_step.py b/falkordb_gemini_kg/steps/create_ontology_step.py
index 8c0b62e..fb8cbb2 100644
--- a/falkordb_gemini_kg/steps/create_ontology_step.py
+++ b/falkordb_gemini_kg/steps/create_ontology_step.py
@@ -7,6 +7,7 @@
     CREATE_ONTOLOGY_SYSTEM,
     CREATE_ONTOLOGY_PROMPT,
     FIX_ONTOLOGY_PROMPT,
+    FIX_JSON_PROMPT,
 )
 import logging
 from falkordb_gemini_kg.helpers import extract_json
@@ -19,6 +20,7 @@
     GenerationResponse,
     FinishReason,
 )
+import json
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -95,12 +97,9 @@ def _process_source(
 
         responses.append(self._call_model(chat_session, user_message))
 
-        logger.debug(f"Model response: {responses[response_idx].text}")
+        logger.debug(f"Model response: {responses[response_idx]}")
 
-        while (
-            responses[response_idx].finish_reason
-            == FinishReason.MAX_TOKENS
-        ):
+        while responses[response_idx].finish_reason == FinishReason.MAX_TOKENS:
             response_idx += 1
             responses.append(self._call_model(chat_session, "continue"))
 
@@ -112,9 +111,28 @@ def _process_source(
         combined_text = " ".join([r.text for r in responses])
 
         try:
-            new_ontology = Ontology.from_json(extract_json(combined_text))
+            data = json.loads(extract_json(combined_text))
+        except json.decoder.JSONDecodeError as e:
+            logger.debug(f"Error extracting JSON: {e}")
+            logger.debug(f"Prompting model to fix JSON")
+            json_fix_response = self._call_model(
+                self._create_chat(),
+                FIX_JSON_PROMPT.format(json=combined_text, error=str(e)),
+            )
+            try:
+                data = json.loads(extract_json(json_fix_response.text))
+                logger.debug(f"Fixed JSON: {data}")
+            except json.decoder.JSONDecodeError as e:
+                logger.error(f"Failed to fix JSON: {e} {json_fix_response.text}")
+                data = None
+
+        if data is None:
+            return o
+
+        try:
+            new_ontology = Ontology.from_json(data)
         except Exception as e:
-            logger.debug(f"Exception while extracting JSON: {e}")
+            logger.error(f"Exception while extracting JSON: {e}")
             new_ontology = None
 
         if new_ontology is not None:
@@ -136,10 +154,7 @@ def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):
 
         logger.debug(f"Model response: {responses[response_idx]}")
 
-        while (
-            responses[response_idx].finish_reason
-            == FinishReason.MAX_TOKENS
-        ):
+        while responses[response_idx].finish_reason == FinishReason.MAX_TOKENS:
             response_idx += 1
             responses.append(self._call_model(chat_session, "continue"))
 
@@ -151,9 +166,28 @@ def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):
         combined_text = " ".join([r.text for r in responses])
 
         try:
-            new_ontology = Ontology.from_json(extract_json(combined_text))
+            data = json.loads(extract_json(combined_text))
+        except json.decoder.JSONDecodeError as e:
+            logger.debug(f"Error extracting JSON: {e}")
+            logger.debug(f"Prompting model to fix JSON")
+            json_fix_response = self._call_model(
+                self._create_chat(),
+                FIX_JSON_PROMPT.format(json=combined_text, error=str(e)),
+            )
+            try:
+                data = json.loads(extract_json(json_fix_response.text))
+                logger.debug(f"Fixed JSON: {data}")
+            except json.decoder.JSONDecodeError as e:
+                logger.error(f"Failed to fix JSON: {e} {json_fix_response.text}")
+                data = None
+
+        if data is None:
+            return o
+
+        try:
+            new_ontology = Ontology.from_json(data)
         except Exception as e:
-            print(f"Exception while extracting JSON: {e}")
+            logger.debug(f"Exception while extracting JSON: {e}")
             new_ontology = None
 
         if new_ontology is not None:
diff --git a/falkordb_gemini_kg/steps/extract_data_step.py b/falkordb_gemini_kg/steps/extract_data_step.py
index 935ff83..8f3cfd1 100644
--- a/falkordb_gemini_kg/steps/extract_data_step.py
+++ b/falkordb_gemini_kg/steps/extract_data_step.py
@@ -60,13 +60,13 @@ def __init__(
     def _create_chat(self):
         return self.model.start_chat({"response_validation": False})
 
-    def run(self):
+    def run(self, instructions: str = None):
         tasks: list[Future[Ontology]] = []
 
         with ThreadPoolExecutor(max_workers=self.config["max_workers"]) as executor:
             # extract entities and relationships from each page
             documents = [
-                document
+                (document, source.instruction)
                 for source in self.sources
                 for document in source.load()
                 if document is not None
@@ -74,7 +74,7 @@ def run(self):
                 and len(document.content) > 0
             ]
             logger.debug(f"Processing {len(documents)} documents")
-            for document in documents:
+            for document, source_instructions in documents:
                 task_id = "extract_data_step_" + str(uuid4())
                 task = executor.submit(
                     self._process_source,
@@ -83,6 +83,8 @@ def run(self):
                     document,
                     self.ontology,
                     self.graph,
+                    source_instructions,
+                    instructions,
                 )
                 tasks.append(task)
 
@@ -96,6 +98,8 @@ def _process_source(
         document: Document,
         ontology: Ontology,
         graph: Graph,
+        source_instructions: str = "",
+        instructions: str = "",
     ):
         try:
             _task_logger = logging.getLogger(task_id)
@@ -114,7 +118,15 @@ def _process_source(
             logger.debug(f"Processing task: {task_id}")
             _task_logger.debug(f"Processing task: {task_id}")
             text = document.content[: self.config["max_input_tokens"]]
-            user_message = EXTRACT_DATA_PROMPT.format(text=text)
+            user_message = EXTRACT_DATA_PROMPT.format(
+                text=text,
+                instructions="\n".join(
+                    [
+                        source_instructions if source_instructions is not None else "",
+                        instructions if instructions is not None else "",
+                    ]
+                ),
+            )
 
             # logger.debug(f"User message: {user_message}")
             _task_logger.debug("User message: " + user_message.replace("\n", " "))
@@ -124,12 +136,9 @@ def _process_source(
 
             responses.append(self._call_model(chat_session, user_message))
 
-            _task_logger.debug(f"Model response: {responses[response_idx]}")
+            _task_logger.debug(f"Model response: {responses[response_idx].text}")
 
-            while (
-                responses[response_idx].finish_reason
-                == FinishReason.MAX_TOKENS
-            ):
+            while responses[response_idx].finish_reason == FinishReason.MAX_TOKENS:
                 _task_logger.debug("Asking model to continue")
                 response_idx += 1
                 responses.append(self._call_model(chat_session, "continue"))
@@ -149,7 +158,7 @@ def _process_source(
 
             try:
                 data = json.loads(extract_json(combined_text))
-            except json.decoder.JSONDecodeError as e:
+            except Exception as e:
                 _task_logger.debug(f"Error extracting JSON: {e}")
                 _task_logger.debug(f"Prompting model to fix JSON")
                 json_fix_response = self._call_model(
@@ -160,22 +169,26 @@ def _process_source(
                 _task_logger.debug(f"Fixed JSON: {data}")
 
             if not "nodes" in data or not "edges" in data:
-                _task_logger.debug(f"Invalid data format: {data}")
-                raise Exception(f"Invalid data format: {data}")
-
+                _task_logger.debug(
+                    f"Invalid data format. Missing nodes or edges. {data}"
+                )
+                raise Exception(
+                    f"Invalid data format. Missing 'nodes' or 'edges' in JSON."
+                )
             for node in data["nodes"]:
                 try:
                     self._create_node(graph, node, ontology)
                 except Exception as e:
-                    logger.exception(e)
+                    _task_logger.error(f"Error creating node: {e}")
                     continue
 
             for edge in data["edges"]:
                 try:
                     self._create_edge(graph, edge, ontology)
                 except Exception as e:
-                    logger.exception(e)
+                    _task_logger.error(f"Error creating edge: {e}")
                     continue
+
         except Exception as e:
             logger.exception(e)
             raise e
@@ -213,9 +226,9 @@ def _create_node(self, graph: Graph, args: dict, ontology: Ontology):
         return result
 
     def _create_edge(self, graph: Graph, args: dict, ontology: Ontology):
-        edge = ontology.get_edge_with_label(args["label"])
-        if edge is None:
-            print(f"Edge with label {args['label']} not found in ontology")
+        edges = ontology.get_edges_with_label(args["label"])
+        if len(edges) == 0:
+            print(f"Edges with label {args['label']} not found in ontology")
             return None
         source_unique_attributes = (
             args["source"]["attributes"]
diff --git a/falkordb_gemini_kg/steps/graph_query_step.py b/falkordb_gemini_kg/steps/graph_query_step.py
index f0a9b64..cfd2b15 100644
--- a/falkordb_gemini_kg/steps/graph_query_step.py
+++ b/falkordb_gemini_kg/steps/graph_query_step.py
@@ -54,8 +54,13 @@ def run(self, question: str, retries: int = 5):
             cypher_statement_response = self.chat_session.send_message(
                 cypher_prompt,
             )
+            logger.debug(f"Cypher Statement Response: {cypher_statement_response}")
             cypher = extract_cypher(cypher_statement_response.text)
             logger.debug(f"Cypher: {cypher}")
+
+            if not cypher or len(cypher) == 0:
+                return (None, None)
+
             validation_errors = validate_cypher(cypher, self.ontology)
             # print(f"Is valid: {is_valid}")
             if validation_errors is not None:
diff --git a/poetry.lock b/poetry.lock
index c7772f6..b21dc17 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -280,6 +280,17 @@ files = [
 [package.dependencies]
 redis = ">=5.0.1,<6.0.0"
 
+[[package]]
+name = "fix-busted-json"
+version = "0.0.18"
+description = "Fixes broken JSON string objects"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "fix-busted-json-0.0.18.tar.gz", hash = "sha256:93c5dab7cae3b5d0b055f2c7043f9fe727a88a80d0be753c5f2c20bb9b69672f"},
+    {file = "fix_busted_json-0.0.18-py3-none-any.whl", hash = "sha256:fdce0e02c9a810b3aa28e1c3c32c24b21b44e89f6315ec25d2b963bd52a6ef03"},
+]
+
 [[package]]
 name = "google-api-core"
 version = "2.19.0"
@@ -1872,4 +1883,4 @@ xai = ["tensorflow (>=2.3.0,<3.0.0dev)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11.4"
-content-hash = "f4afb02d8a5f1042e2c2c170e4de0e4a25303a52e60b8396ca7dd87be1588a48"
+content-hash = "36e08a4df78f4d028fd6959b6a953096d9470c3f9b68d70b79b2b16c56c5aa5a"
diff --git a/pyproject.toml b/pyproject.toml
index 6b83308..255b59e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ python-abc = "^0.2.0"
 ratelimit = "^2.2.1"
 python-dotenv = "^1.0.1"
 openai = "^1.35.9"
+fix-busted-json = "^0.0.18"
 
 [tool.poetry.group.test.dependencies]
 pytest = "^8.2.1"
diff --git a/tests/test_kg_gemini.py b/tests/test_kg_gemini.py
index 86ef4c5..7c11945 100644
--- a/tests/test_kg_gemini.py
+++ b/tests/test_kg_gemini.py
@@ -72,9 +72,11 @@ def setUpClass(cls):
             )
         )
 
+        cls.graph_name = "IMDB_gemini"
+
         model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
         cls.kg = KnowledgeGraph(
-            name="IMDB",
+            name=cls.graph_name,
             ontology=cls.ontology,
             model_config=KnowledgeGraphModelConfig.with_model(model),
         )
@@ -99,4 +101,4 @@ def test_kg_delete(self):
         db = FalkorDB()
         graphs = db.list_graphs()
 
-        self.assertNotIn("IMDB", graphs)
+        self.assertNotIn(self.graph_name, graphs)
diff --git a/tests/test_kg_openai.py b/tests/test_kg_openai.py
index 53aab32..2f67468 100644
--- a/tests/test_kg_openai.py
+++ b/tests/test_kg_openai.py
@@ -69,10 +69,10 @@ def setUpClass(cls):
             ],
         )
         )
-
+        cls.graph_name = "IMDB_openai"
         model = OpenAiGenerativeModel(model_name="gpt-3.5-turbo-0125")
         cls.kg = KnowledgeGraph(
-            name="IMDB",
+            name=cls.graph_name,
             ontology=cls.ontology,
             model_config=KnowledgeGraphModelConfig.with_model(model),
         )
@@ -97,4 +97,4 @@ def test_kg_delete(self):
         db = FalkorDB()
         graphs = db.list_graphs()
 
-        self.assertNotIn("IMDB", graphs)
+        self.assertNotIn(self.graph_name, graphs)
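
Taken together, the changes above thread an optional free-text `instructions` parameter from `KnowledgeGraph.process_sources` through `ExtractDataStep.run` into the new `{instructions}` placeholder of `EXTRACT_DATA_PROMPT`, and make question-answering return a fallback message when no Cypher statement can be generated. A minimal usage sketch follows; the top-level import path and the pre-built `ontology` and `sources` objects are assumptions for illustration, not shown in this diff:

```python
# Sketch only: import paths are inferred from the modules touched in this diff
# and may differ from the package's actual public API.
from falkordb_gemini_kg import KnowledgeGraph, KnowledgeGraphModelConfig
from falkordb_gemini_kg.models.openai import OpenAiGenerativeModel

model = OpenAiGenerativeModel(model_name="gpt-3.5-turbo-0125")
kg = KnowledgeGraph(
    name="IMDB_openai",
    ontology=ontology,  # an Ontology instance built beforehand
    model_config=KnowledgeGraphModelConfig.with_model(model),
)

# `instructions` is joined with each source's own `instruction` attribute and
# injected into EXTRACT_DATA_PROMPT for every processed document.
kg.process_sources(sources, instructions="Dates must use the YYYY-MM-DD format.")

# ask() now returns an apology string instead of failing when the model
# produces no Cypher statement for the question.
print(kg.ask("Which actors played in The Matrix?"))
```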