Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
add cypher edge validation use case
Browse files Browse the repository at this point in the history
  • Loading branch information
dudizimber committed Jul 10, 2024
1 parent be43a3d commit 8e47ceb
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 9 deletions.
2 changes: 1 addition & 1 deletion falkordb_gemini_kg/classes/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
self.required = required

@staticmethod
def from_json(txt: str):
def from_json(txt: str | dict):
txt = txt if isinstance(txt, dict) else json.loads(txt)
if txt["type"] not in [
AttributeType.STRING,
Expand Down
18 changes: 13 additions & 5 deletions falkordb_gemini_kg/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
logger = logging.getLogger(__name__)


def extract_json(text: str):
def extract_json(text: str | dict) -> str:
if not isinstance(text, str):
text = str(text)
regex = r"(?:```)?(?:json)?([^`]*)(?:\\n)?(?:```)?"
matches = re.findall(regex, text, re.DOTALL)

return repair_json("".join(matches))
try:
return repair_json("".join(matches))
except Exception as e:
logger.error(f"Failed to repair JSON: {e}")
return "".join(matches)


def map_dict_to_cypher_properties(d: dict):
Expand Down Expand Up @@ -143,15 +149,17 @@ def validate_cypher_edge_directions(cypher: str, ontology: falkordb_gemini_kg.On
if prev_edge
else cypher[: edge.start()]
)
rel_before = re.search(r"([^\)\]]+)", before[::-1]).group(0)[::-1]
if "," in before:
before = before.split(",")[-1]
rel_before = re.search(r"([^\)\],]+)", before[::-1]).group(0)[::-1]
after = (
cypher[edge.end() : next_edge.start()]
if next_edge
else cypher[edge.end() :]
)
rel_after = re.search(r"([^\(\[]+)", after).group(0)
rel_after = re.search(r"([^\(\[,]+)", after).group(0)
node_before = re.search(r"\(.+:(.*?)\)", before).group(0)
node_after = re.search(r"\(([^\)]+)(\)?)", after).group(0)
node_after = re.search(r"\(([^\),]+)(\)?)", after).group(0)
if rel_before == "-" and rel_after == "->":
source = node_before
target = node_after
Expand Down
90 changes: 87 additions & 3 deletions tests/test_helper_validate_cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class TestValidateCypher1(unittest.TestCase):
"""
Test a valid cypher query
Test a valid cypher query
"""

cypher = """
Expand Down Expand Up @@ -82,10 +82,9 @@ def test_validate_cypher(self):

class TestValidateCypher2(unittest.TestCase):
"""
Test a cypher query with the wrong edge direction
Test a cypher query with the wrong edge direction
"""


cypher = """
MATCH (f:Fighter)<-[r:FOUGHT_IN]-(fight:Fight)
RETURN f"""
Expand Down Expand Up @@ -142,5 +141,90 @@ def test_validate_cypher(self):
assert errors is not None


class TestValidateCypher3(unittest.TestCase):
"""
Test a cypher query with multiple right edge directions
"""

cypher = """
MATCH (a:Airline)-[:ACCEPTS]->(p:Pet), (r:Route)-[:ALLOWS]->(sd:Service_Dog)
RETURN a, p, r, sd
"""

@classmethod
def setUpClass(cls):

cls._ontology = Ontology([], [])

cls._ontology.add_node(
Node(
label="Airline",
attributes=[],
)
)

cls._ontology.add_node(
Node(
label="Pet",
attributes=[],
)
)

cls._ontology.add_node(
Node(
label="Route",
attributes=[],
)
)

cls._ontology.add_node(
Node(
label="Service_Dog",
attributes=[],
)
)

cls._ontology.add_edge(
Edge(
label="ACCEPTS",
source="Airline",
target="Pet",
attributes=[],
)
)

cls._ontology.add_edge(
Edge(
label="ALLOWS",
source="Route",
target="Service_Dog",
attributes=[],
)
)

def test_validate_cypher_nodes_exist(self):

errors = validate_cypher_nodes_exist(self.cypher, self._ontology)

assert len(errors) == 0

def test_validate_cypher_edges_exist(self):

errors = validate_cypher_edges_exist(self.cypher, self._ontology)

assert len(errors) == 0

def test_validate_cypher_edge_directions(self):

errors = validate_cypher_edge_directions(self.cypher, self._ontology)

assert len(errors) == 0

def test_validate_cypher(self):
errors = validate_cypher(self.cypher, self._ontology)

assert errors is None or len(errors) == 0


if __name__ == "__main__":
unittest.main()

0 comments on commit 8e47ceb

Please sign in to comment.