Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/4-support-openai' into 3-support…
Browse files Browse the repository at this point in the history
…-multi-agent-architecture
  • Loading branch information
dudizimber committed Jul 9, 2024
2 parents 5ab9eda + be43a3d commit 0c8ed3f
Show file tree
Hide file tree
Showing 18 changed files with 188 additions and 79 deletions.
3 changes: 3 additions & 0 deletions falkordb_gemini_kg/classes/ChatSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def send_message(self, message: str):

(context, cypher) = cypher_step.run(message)

if not cypher or len(cypher) == 0:
return "I am sorry, I could not find the answer to your question"

qa_step = QAStep(
chat_session=self.qa_chat_session,
)
Expand Down
4 changes: 2 additions & 2 deletions falkordb_gemini_kg/classes/attribute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from falkordb_gemini_kg.fixtures.regex import *
import logging
import re

logger = logging.getLogger(__name__)

Expand All @@ -23,7 +24,7 @@ class Attribute:
def __init__(
self, name: str, attr_type: AttributeType, unique: bool, required: bool = False
):
self.name = name
self.name = re.sub(r"([^a-zA-Z0-9_])", "_", name)
self.type = attr_type
self.unique = unique
self.required = required
Expand Down Expand Up @@ -70,4 +71,3 @@ def to_json(self):

def __str__(self) -> str:
return f"{self.name}: \"{self.type}{'!' if self.unique else ''}{'*' if self.required else ''}\""

14 changes: 7 additions & 7 deletions falkordb_gemini_kg/classes/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

logger = logging.getLogger(__name__)


class _EdgeNode:
def __init__(self, label: str):
self.label = label
self.label = re.sub(r"([^a-zA-Z0-9_])", "", label)

@staticmethod
def from_json(txt: str):
txt = txt if isinstance(txt, dict) else json.loads(txt)
return _EdgeNode(txt["label"])
return _EdgeNode(txt["label"] if "label" in txt else txt)

def to_json(self):
return {"label": self.label}
Expand All @@ -28,22 +29,21 @@ def __init__(
self,
label: str,
source: _EdgeNode | str,
target: _EdgeNode | str,
target: _EdgeNode | str,
attributes: list[Attribute],
):

if isinstance(source, str):
source = _EdgeNode(source)
if isinstance(target, str):
target = _EdgeNode(target)

assert isinstance(label, str), "Label must be a string"
assert isinstance(source, _EdgeNode), "Source must be an EdgeNode"
assert isinstance(target, _EdgeNode), "Target must be an EdgeNode"
assert isinstance(attributes, list), "Attributes must be a list"


self.label = label
self.label = re.sub(r"([^a-zA-Z0-9_])", "", label.upper())
self.source = source
self.target = target
self.attributes = attributes
Expand Down
27 changes: 19 additions & 8 deletions falkordb_gemini_kg/classes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import logging
from .attribute import Attribute, AttributeType
from falkordb import Node as GraphNode

import re
logger = logging.getLogger(__name__)

descriptionKey = "__description__"

class Node:
def __init__(self, label: str, attributes: list[Attribute]):
self.label = label
def __init__(self, label: str, attributes: list[Attribute], description: str = ""):
self.label = re.sub(r"([^a-zA-Z0-9_])", "", label)
self.attributes = attributes
self.description = description

@staticmethod
def from_graph(node: GraphNode):
Expand All @@ -21,22 +24,28 @@ def from_graph(node: GraphNode):
AttributeType.fromString(node.properties[attr]),
"!" in node.properties[attr],
)
for attr in node.properties
for attr in node.properties if attr != descriptionKey
],
node.properties[descriptionKey] if descriptionKey in node.properties else "",
)

@staticmethod
def from_json(txt: dict | str):
txt = txt if isinstance(txt, dict) else json.loads(txt)
return Node(
txt["label"].replace(" ", ""),
[Attribute.from_json(attr) for attr in txt["attributes"]],
txt["label"],
[
Attribute.from_json(attr)
for attr in (txt["attributes"] if "attributes" in txt else [])
],
txt["description"] if "description" in txt else "",
)

def to_json(self):
return {
"label": self.label,
"attributes": [attr.to_json() for attr in self.attributes],
"description": self.description,
}

def combine(self, node2: "Node"):
Expand All @@ -55,10 +64,12 @@ def get_unique_attributes(self):
return [attr for attr in self.attributes if attr.unique]

def to_graph_query(self):
return f"MERGE (n:{self.label} {{{', '.join([str(attr) for attr in self.attributes])}}}) RETURN n"
attributes = ", ".join([str(attr) for attr in self.attributes])
if self.description:
attributes += f"{', ' if len(attributes) > 0 else ''} {descriptionKey}: '{self.description}'"
return f"MERGE (n:{self.label} {{{attributes}}}) RETURN n"

def __str__(self) -> str:
return (
f"(:{self.label} {{{', '.join([str(attr) for attr in self.attributes])}}})"
)

4 changes: 2 additions & 2 deletions falkordb_gemini_kg/classes/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def validate_nodes(self):
def get_node_with_label(self, label: str):
return next((n for n in self.nodes if n.label == label), None)

def get_edge_with_label(self, label: str):
return next((e for e in self.edges if e.label == label), None)
def get_edges_with_label(self, label: str):
return [e for e in self.edges if e.label == label]

def has_node_with_label(self, label: str):
return any(n.label == label for n in self.nodes)
Expand Down
2 changes: 1 addition & 1 deletion falkordb_gemini_kg/document_loaders/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, url: str) -> None:

def _download(self) -> str:
try:
response = requests.get(self.url)
response = requests.get(self.url, headers={'User-Agent': 'Mozilla/5.0'})
response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx)
return response.text
except requests.exceptions.RequestException as e:
Expand Down
6 changes: 5 additions & 1 deletion falkordb_gemini_kg/fixtures/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Add as many attributes to nodes and edges as necessary to fully describe the entities and relationships in the text.
Prefer to convert edges into nodes when they have attributes. For example, if an edge represents a relationship with attributes, convert it into a node with the attributes as properties.
Create a very concise and clear ontology. Avoid unnecessary complexity and ambiguity in the ontology.
Node and edge labels cannot start with numbers or special characters.
## 2. Labeling Nodes
- **Consistency**: Ensure you use available types for node labels. Ensure you use basic or elementary types for node labels. For example, when you identify an entity representing a person, always label it as **'person'**. Avoid using more specific terms "like 'mathematician' or 'scientist'"
Expand Down Expand Up @@ -338,6 +339,9 @@
All formats should be consistent, for example, dates should be in the format "YYYY-MM-DD".
If needed, add the correct spacing for text fields, where the text is not properly formatted.
User instructions:
{instructions}
Raw Text:
{text}
"""
Expand Down Expand Up @@ -410,7 +414,7 @@
Try to generate a new valid OpenCypher statement.
Use only the provided nodes, relationships types and properties in the ontology.
The output must be only a valid OpenCypher statement.
Do not include any text except the generated OpenCypher statement, enclosed in triple backticks.
Do not include any apologies or other texts, except the generated OpenCypher statement, enclosed in triple backticks.
Question: {question}
"""
Expand Down
49 changes: 34 additions & 15 deletions falkordb_gemini_kg/helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import re
from falkordb_gemini_kg.classes.ontology import Ontology
import falkordb_gemini_kg
import logging
from fix_busted_json import repair_json

logger = logging.getLogger(__name__)


def extract_json(text: str):
regex = r"(?:```)?(?:json)?([^`]*)(?:\\n)?(?:```)?"
matches = re.findall(regex, text, re.DOTALL)

return "".join(matches)
return repair_json("".join(matches))


def map_dict_to_cypher_properties(d: dict):
Expand Down Expand Up @@ -69,10 +71,12 @@ def extract_cypher(text: str):
return "".join(matches)


def validate_cypher(cypher: str, ontology: Ontology) -> list[str] | None:
def validate_cypher(
cypher: str, ontology: falkordb_gemini_kg.Ontology
) -> list[str] | None:
try:
if not cypher or len(cypher) == 0:
return "Cypher statement is empty"
return ["Cypher statement is empty"]

errors = []

Expand All @@ -94,7 +98,7 @@ def validate_cypher(cypher: str, ontology: Ontology) -> list[str] | None:
return None


def validate_cypher_nodes_exist(cypher: str, ontology: Ontology):
def validate_cypher_nodes_exist(cypher: str, ontology: falkordb_gemini_kg.Ontology):
# Check if nodes exist in ontology
not_found_node_labels = []
node_labels = re.findall(r"\(:(.*?)\)", cypher)
Expand All @@ -107,7 +111,7 @@ def validate_cypher_nodes_exist(cypher: str, ontology: Ontology):
return [f"Node {label} not found in ontology" for label in not_found_node_labels]


def validate_cypher_edges_exist(cypher: str, ontology: Ontology):
def validate_cypher_edges_exist(cypher: str, ontology: falkordb_gemini_kg.Ontology):
# Check if edges exist in ontology
not_found_edge_labels = []
edge_labels = re.findall(r"\[:(.*?)\]", cypher)
Expand All @@ -120,7 +124,7 @@ def validate_cypher_edges_exist(cypher: str, ontology: Ontology):
return [f"Edge {label} not found in ontology" for label in not_found_edge_labels]


def validate_cypher_edge_directions(cypher: str, ontology: Ontology):
def validate_cypher_edge_directions(cypher: str, ontology: falkordb_gemini_kg.Ontology):

errors = []
edges = list(re.finditer(r"\[.*?\]", cypher))
Expand Down Expand Up @@ -160,22 +164,37 @@ def validate_cypher_edge_directions(cypher: str, ontology: Ontology):
source_label = re.search(r"(?:\:)([^\)\{]+)", source).group(1).strip()
target_label = re.search(r"(?:\:)([^\)\{]+)", target).group(1).strip()

ontology_edge = ontology.get_edge_with_label(edge_label)
ontology_edges = ontology.get_edges_with_label(edge_label)

if ontology_edge is None:
if len(ontology_edges) == 0:
errors.append(f"Edge {edge_label} not found in ontology")

if (
not ontology_edge.source.label == source_label
or not ontology_edge.target.label == target_label
):
found_edge = False
for ontology_edge in ontology_edges:
if (
ontology_edge.source.label == source_label
and ontology_edge.target.label == target_label
):
found_edge = True
break

if not found_edge:
errors.append(
f"Edge {edge_label} has a mismatched source or target. Make sure the edge direction is correct. The edge should connect {ontology_edge.source.label} to {ontology_edge.target.label}."
"""
Edge {edge_label} does not connect {source_label} to {target_label}. Make sure the edge direction is correct.
Valid edges:
{valid_edges}
""".format(
edge_label=edge_label,
source_label=source_label,
target_label=target_label,
valid_edges="\n".join([str(e) for e in ontology_edges]),
)
)

i += 1
except Exception as e:
errors.append(str(e))
# errors.append(str(e))
continue

return errors
11 changes: 7 additions & 4 deletions falkordb_gemini_kg/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def list_sources(self) -> list[AbstractSource]:

return [s.source for s in self.sources]

def process_sources(self, sources: list[AbstractSource]) -> None:
def process_sources(self, sources: list[AbstractSource], instructions: str = None) -> None:
"""
Add entities and relations found in sources into the knowledge-graph
Expand All @@ -94,13 +94,13 @@ def process_sources(self, sources: list[AbstractSource]) -> None:
raise Exception("Ontology is not defined")

# Create graph with sources
self._create_graph_with_sources(sources)
self._create_graph_with_sources(sources, instructions)

# Add processed sources
for src in sources:
self.sources.add(src)

def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None):
def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None, instructions: str = None):

step = ExtractDataStep(
sources=list(sources),
Expand All @@ -109,7 +109,7 @@ def _create_graph_with_sources(self, sources: list[AbstractSource] | None = None
graph=self.graph,
)

step.run()
step.run(instructions)

def ask(self, question: str) -> str:
"""
Expand Down Expand Up @@ -138,6 +138,9 @@ def ask(self, question: str) -> str:

(context, cypher) = cypher_step.run(question)

if not cypher or len(cypher) == 0:
return "I am sorry, I could not find the answer to your question"

qa_chat_session = self._model_config.qa.with_system_instruction(
GRAPH_QA_SYSTEM
).start_chat()
Expand Down
3 changes: 3 additions & 0 deletions falkordb_gemini_kg/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, text: str, finish_reason: FinishReason):
self.text = text
self.finish_reason = finish_reason

def __str__(self) -> str:
return f"GenerationResponse(text={self.text}, finish_reason={self.finish_reason})"


class GenerativeModelChatSession(ABC):

Expand Down
4 changes: 2 additions & 2 deletions falkordb_gemini_kg/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def ask(self, message: str) -> GenerationResponse:
model=self.model_name,
messages=[
{"role": "system", "content": self.system_instruction},
{"role": "user", "content": message},
{"role": "user", "content": message[:14385]},
],
max_tokens=self.generation_config.max_output_tokens,
temperature=self.generation_config.temperature,
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self, model: OpenAiGenerativeModel, args: dict | None = None):
def send_message(self, message: str) -> GenerationResponse:
prompt = []
prompt.extend(self._history)
prompt.append({"role": "user", "content": message})
prompt.append({"role": "user", "content": message[:14385]})
response = self._model.client.chat.completions.create(
model=self._model.model_name,
messages=prompt,
Expand Down
Loading

0 comments on commit 0c8ed3f

Please sign in to comment.