Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/4-support-openai' into 3-support…
Browse files Browse the repository at this point in the history
…-multi-agent-architecture
  • Loading branch information
dudizimber committed Jul 10, 2024
2 parents d2ba29f + 8e47ceb commit d99261d
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 7 deletions.
2 changes: 1 addition & 1 deletion falkordb_gemini_kg/classes/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
self.required = required

@staticmethod
def from_json(txt: str):
def from_json(txt: str | dict):
txt = txt if isinstance(txt, dict) else json.loads(txt)
if txt["type"] not in [
AttributeType.STRING,
Expand Down
14 changes: 10 additions & 4 deletions falkordb_gemini_kg/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
logger = logging.getLogger(__name__)


def extract_json(text: str):
def extract_json(text: str | dict) -> str:
if not isinstance(text, str):
text = str(text)
regex = r"(?:```)?(?:json)?([^`]*)(?:\\n)?(?:```)?"
matches = re.findall(regex, text, re.DOTALL)

try:
return repair_json("".join(matches))
except Exception as e:
logger.error(f"Failed to repair JSON: {e}")
return "".join(matches)


def map_dict_to_cypher_properties(d: dict):
cypher = "{"
if isinstance(d, list):
Expand Down Expand Up @@ -145,15 +149,17 @@ def validate_cypher_relation_directions(cypher: str, ontology: falkordb_gemini_k
if prev_relation
else cypher[: relation.start()]
)
rel_before = re.search(r"([^\)\]]+)", before[::-1]).group(0)[::-1]
if "," in before:
before = before.split(",")[-1]
rel_before = re.search(r"([^\)\],]+)", before[::-1]).group(0)[::-1]
after = (
cypher[relation.end() : next_relation.start()]
if next_relation
else cypher[relation.end() :]
)
rel_after = re.search(r"([^\(\[]+)", after).group(0)
rel_after = re.search(r"([^\(\[,]+)", after).group(0)
entity_before = re.search(r"\(.+:(.*?)\)", before).group(0)
entity_after = re.search(r"\(([^\)]+)(\)?)", after).group(0)
entity_after = re.search(r"\(([^\),]+)(\)?)", after).group(0)
if rel_before == "-" and rel_after == "->":
source = entity_before
target = entity_after
Expand Down
88 changes: 86 additions & 2 deletions tests/test_helper_validate_cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class TestValidateCypher1(unittest.TestCase):
"""
Test a valid cypher query
Test a valid cypher query
"""

cypher = """
Expand Down Expand Up @@ -85,7 +85,6 @@ class TestValidateCypher2(unittest.TestCase):
Test a cypher query with the wrong relation direction
"""


cypher = """
MATCH (f:Fighter)<-[r:FOUGHT_IN]-(fight:Fight)
RETURN f"""
Expand Down Expand Up @@ -142,5 +141,90 @@ def test_validate_cypher(self):
assert errors is not None


class TestValidateCypher3(unittest.TestCase):
"""
Test a cypher query with multiple right edge directions
"""

cypher = """
MATCH (a:Airline)-[:ACCEPTS]->(p:Pet), (r:Route)-[:ALLOWS]->(sd:Service_Dog)
RETURN a, p, r, sd
"""

@classmethod
def setUpClass(cls):

cls._ontology = Ontology([], [])

cls._ontology.add_node(
Node(
label="Airline",
attributes=[],
)
)

cls._ontology.add_node(
Node(
label="Pet",
attributes=[],
)
)

cls._ontology.add_node(
Node(
label="Route",
attributes=[],
)
)

cls._ontology.add_node(
Node(
label="Service_Dog",
attributes=[],
)
)

cls._ontology.add_edge(
Edge(
label="ACCEPTS",
source="Airline",
target="Pet",
attributes=[],
)
)

cls._ontology.add_edge(
Edge(
label="ALLOWS",
source="Route",
target="Service_Dog",
attributes=[],
)
)

def test_validate_cypher_nodes_exist(self):

errors = validate_cypher_nodes_exist(self.cypher, self._ontology)

assert len(errors) == 0

def test_validate_cypher_edges_exist(self):

errors = validate_cypher_edges_exist(self.cypher, self._ontology)

assert len(errors) == 0

def test_validate_cypher_edge_directions(self):

errors = validate_cypher_edge_directions(self.cypher, self._ontology)

assert len(errors) == 0

def test_validate_cypher(self):
errors = validate_cypher(self.cypher, self._ontology)

assert errors is None or len(errors) == 0


if __name__ == "__main__":
unittest.main()

0 comments on commit d99261d

Please sign in to comment.