diff --git a/falkordb_gemini_kg/classes/attribute.py b/falkordb_gemini_kg/classes/attribute.py index 8fcee78..7e899fa 100644 --- a/falkordb_gemini_kg/classes/attribute.py +++ b/falkordb_gemini_kg/classes/attribute.py @@ -30,7 +30,7 @@ def __init__( self.required = required @staticmethod - def from_json(txt: str): + def from_json(txt: str | dict): txt = txt if isinstance(txt, dict) else json.loads(txt) if txt["type"] not in [ AttributeType.STRING, diff --git a/falkordb_gemini_kg/helpers.py b/falkordb_gemini_kg/helpers.py index 67f6226..14970af 100644 --- a/falkordb_gemini_kg/helpers.py +++ b/falkordb_gemini_kg/helpers.py @@ -6,15 +6,19 @@ logger = logging.getLogger(__name__) -def extract_json(text: str): +def extract_json(text: str | dict) -> str: + if not isinstance(text, str): + text = str(text) regex = r"(?:```)?(?:json)?([^`]*)(?:\\n)?(?:```)?" matches = re.findall(regex, text, re.DOTALL) try: return repair_json("".join(matches)) except Exception as e: + logger.error(f"Failed to repair JSON: {e}") return "".join(matches) + def map_dict_to_cypher_properties(d: dict): cypher = "{" if isinstance(d, list): @@ -145,15 +149,17 @@ def validate_cypher_relation_directions(cypher: str, ontology: falkordb_gemini_k if prev_relation else cypher[: relation.start()] ) - rel_before = re.search(r"([^\)\]]+)", before[::-1]).group(0)[::-1] + if "," in before: + before = before.split(",")[-1] + rel_before = re.search(r"([^\)\],]+)", before[::-1]).group(0)[::-1] after = ( cypher[relation.end() : next_relation.start()] if next_relation else cypher[relation.end() :] ) - rel_after = re.search(r"([^\(\[]+)", after).group(0) + rel_after = re.search(r"([^\(\[,]+)", after).group(0) entity_before = re.search(r"\(.+:(.*?)\)", before).group(0) - entity_after = re.search(r"\(([^\)]+)(\)?)", after).group(0) + entity_after = re.search(r"\(([^\),]+)(\)?)", after).group(0) if rel_before == "-" and rel_after == "->": source = entity_before target = entity_after diff --git a/tests/test_helper_validate_cypher.py b/tests/test_helper_validate_cypher.py index 0899f6c..83eef1e 100644 --- a/tests/test_helper_validate_cypher.py +++ b/tests/test_helper_validate_cypher.py @@ -19,7 +19,7 @@ class TestValidateCypher1(unittest.TestCase): """ - Test a valid cypher query + Test a valid cypher query """ cypher = """ @@ -85,7 +85,6 @@ class TestValidateCypher2(unittest.TestCase): Test a cypher query with the wrong relation direction """ - cypher = """ MATCH (f:Fighter)<-[r:FOUGHT_IN]-(fight:Fight) RETURN f""" @@ -142,5 +141,90 @@ def test_validate_cypher(self): assert errors is not None +class TestValidateCypher3(unittest.TestCase): + """ + Test a cypher query with multiple right edge directions + """ + + cypher = """ + MATCH (a:Airline)-[:ACCEPTS]->(p:Pet), (r:Route)-[:ALLOWS]->(sd:Service_Dog) + RETURN a, p, r, sd + """ + + @classmethod + def setUpClass(cls): + + cls._ontology = Ontology([], []) + + cls._ontology.add_node( + Node( + label="Airline", + attributes=[], + ) + ) + + cls._ontology.add_node( + Node( + label="Pet", + attributes=[], + ) + ) + + cls._ontology.add_node( + Node( + label="Route", + attributes=[], + ) + ) + + cls._ontology.add_node( + Node( + label="Service_Dog", + attributes=[], + ) + ) + + cls._ontology.add_edge( + Edge( + label="ACCEPTS", + source="Airline", + target="Pet", + attributes=[], + ) + ) + + cls._ontology.add_edge( + Edge( + label="ALLOWS", + source="Route", + target="Service_Dog", + attributes=[], + ) + ) + + def test_validate_cypher_nodes_exist(self): + + errors = validate_cypher_nodes_exist(self.cypher, self._ontology) + + assert len(errors) == 0 + + def test_validate_cypher_edges_exist(self): + + errors = validate_cypher_edges_exist(self.cypher, self._ontology) + + assert len(errors) == 0 + + def test_validate_cypher_edge_directions(self): + + errors = validate_cypher_edge_directions(self.cypher, self._ontology) + + assert len(errors) == 0 + + def test_validate_cypher(self): + errors = validate_cypher(self.cypher, self._ontology) + + assert errors is None or len(errors) == 0 + + if __name__ == "__main__": unittest.main()