From 8e47cebd3c0de93a146570da5409a4a55b11304a Mon Sep 17 00:00:00 2001 From: Dudi Zimberknopf Date: Wed, 10 Jul 2024 14:10:43 +0300 Subject: [PATCH] add cypher edge validation use case --- falkordb_gemini_kg/classes/attribute.py | 2 +- falkordb_gemini_kg/helpers.py | 18 +++-- tests/test_helper_validate_cypher.py | 90 ++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 9 deletions(-) diff --git a/falkordb_gemini_kg/classes/attribute.py b/falkordb_gemini_kg/classes/attribute.py index 8fcee78..7e899fa 100644 --- a/falkordb_gemini_kg/classes/attribute.py +++ b/falkordb_gemini_kg/classes/attribute.py @@ -30,7 +30,7 @@ def __init__( self.required = required @staticmethod - def from_json(txt: str): + def from_json(txt: str | dict): txt = txt if isinstance(txt, dict) else json.loads(txt) if txt["type"] not in [ AttributeType.STRING, diff --git a/falkordb_gemini_kg/helpers.py b/falkordb_gemini_kg/helpers.py index bd25675..b4629f5 100644 --- a/falkordb_gemini_kg/helpers.py +++ b/falkordb_gemini_kg/helpers.py @@ -6,11 +6,17 @@ logger = logging.getLogger(__name__) -def extract_json(text: str): +def extract_json(text: str | dict) -> str: + if not isinstance(text, str): + text = str(text) regex = r"(?:```)?(?:json)?([^`]*)(?:\\n)?(?:```)?" matches = re.findall(regex, text, re.DOTALL) - return repair_json("".join(matches)) + try: + return repair_json("".join(matches)) + except Exception as e: + logger.error(f"Failed to repair JSON: {e}") + return "".join(matches) def map_dict_to_cypher_properties(d: dict): @@ -143,15 +149,17 @@ def validate_cypher_edge_directions(cypher: str, ontology: falkordb_gemini_kg.On if prev_edge else cypher[: edge.start()] ) - rel_before = re.search(r"([^\)\]]+)", before[::-1]).group(0)[::-1] + if "," in before: + before = before.split(",")[-1] + rel_before = re.search(r"([^\)\],]+)", before[::-1]).group(0)[::-1] after = ( cypher[edge.end() : next_edge.start()] if next_edge else cypher[edge.end() :] ) - rel_after = re.search(r"([^\(\[]+)", after).group(0) + rel_after = re.search(r"([^\(\[,]+)", after).group(0) node_before = re.search(r"\(.+:(.*?)\)", before).group(0) - node_after = re.search(r"\(([^\)]+)(\)?)", after).group(0) + node_after = re.search(r"\(([^\),]+)(\)?)", after).group(0) if rel_before == "-" and rel_after == "->": source = node_before target = node_after diff --git a/tests/test_helper_validate_cypher.py b/tests/test_helper_validate_cypher.py index e3f2c89..586f1c9 100644 --- a/tests/test_helper_validate_cypher.py +++ b/tests/test_helper_validate_cypher.py @@ -19,7 +19,7 @@ class TestValidateCypher1(unittest.TestCase): """ - Test a valid cypher query + Test a valid cypher query """ cypher = """ @@ -82,10 +82,9 @@ def test_validate_cypher(self): class TestValidateCypher2(unittest.TestCase): """ - Test a cypher query with the wrong edge direction + Test a cypher query with the wrong edge direction """ - cypher = """ MATCH (f:Fighter)<-[r:FOUGHT_IN]-(fight:Fight) RETURN f""" @@ -142,5 +141,90 @@ def test_validate_cypher(self): assert errors is not None +class TestValidateCypher3(unittest.TestCase): + """ + Test a cypher query with multiple right edge directions + """ + + cypher = """ + MATCH (a:Airline)-[:ACCEPTS]->(p:Pet), (r:Route)-[:ALLOWS]->(sd:Service_Dog) + RETURN a, p, r, sd + """ + + @classmethod + def setUpClass(cls): + + cls._ontology = Ontology([], []) + + cls._ontology.add_node( + Node( + label="Airline", + attributes=[], + ) + ) + + cls._ontology.add_node( + Node( + label="Pet", + attributes=[], + ) + ) + + cls._ontology.add_node( + Node( + label="Route", + attributes=[], + ) + ) + + cls._ontology.add_node( + Node( + label="Service_Dog", + attributes=[], + ) + ) + + cls._ontology.add_edge( + Edge( + label="ACCEPTS", + source="Airline", + target="Pet", + attributes=[], + ) + ) + + cls._ontology.add_edge( + Edge( + label="ALLOWS", + source="Route", + target="Service_Dog", + attributes=[], + ) + ) + + def test_validate_cypher_nodes_exist(self): + + errors = validate_cypher_nodes_exist(self.cypher, self._ontology) + + assert len(errors) == 0 + + def test_validate_cypher_edges_exist(self): + + errors = validate_cypher_edges_exist(self.cypher, self._ontology) + + assert len(errors) == 0 + + def test_validate_cypher_edge_directions(self): + + errors = validate_cypher_edge_directions(self.cypher, self._ontology) + + assert len(errors) == 0 + + def test_validate_cypher(self): + errors = validate_cypher(self.cypher, self._ontology) + + assert errors is None or len(errors) == 0 + + if __name__ == "__main__": unittest.main()