Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Dorbmon committed Sep 17, 2024
1 parent 3753c7b commit 32e9ddd
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ jobs:
postgres-version: "14"
ssl: "on"
id: postgres
- name: Install pgvector
run: |
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y
sudo apt-get install postgresql-14-pgvector
- name: Build and Test
env:
POSTGRES_CONNECTION_STR: ${{ steps.postgres.outputs.connection-uri }}
Expand Down
129 changes: 129 additions & 0 deletions nano_graphrag/storage/neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from nano_graphrag._storage import BaseGraphStorage
from neo4j import AsyncGraphDatabase
import neo4j
from typing import Union
from nano_graphrag.graphrag import always_get_an_event_loop
from nano_graphrag.base import SingleCommunitySchema
import numpy as np

# Patch asyncio to allow re-entrant run_until_complete() calls: the storage
# constructor in this module drives a coroutine synchronously via
# loop.run_until_complete(), which would otherwise fail if the event loop is
# already running when the storage is instantiated.
import nest_asyncio
nest_asyncio.apply()

class NetworkXStorage(BaseGraphStorage):
    """Graph storage backed by a Neo4j database through the async driver.

    NOTE: the class name is kept as-is for backward compatibility even though
    the backend is Neo4j rather than NetworkX.

    Data model used consistently by every method of this class:

    * every node carries the label ``Node``;
    * the application-level node identifier is stored in the ``_id``
      property (never the Neo4j-internal ``id()``), and is unique per node;
    * edges are stored as directed relationships but are queried without
      regard to direction, i.e. the graph is treated as undirected.
    """

    def __init__(self, uri: str, user: str, password: str):
        """Open a driver for *uri* and make sure the schema constraint exists.

        Args:
            uri: Connection URI of the Neo4j server.
            user: User name used for authentication.
            password: Password used for authentication.
        """
        # The driver must be built through the ``driver`` factory;
        # ``AsyncGraphDatabase`` itself is not meant to be instantiated.
        self._driver: neo4j.AsyncDriver = AsyncGraphDatabase.driver(
            uri, auth=(user, password)
        )
        # __init__ cannot be async, so run the schema setup to completion here.
        loop = always_get_an_event_loop()
        loop.run_until_complete(self._secure_table())

    async def _secure_table(self):
        """Create the uniqueness constraint on ``Node._id`` if it is missing."""
        # ``IF NOT EXISTS`` keeps this idempotent so that constructing the
        # storage a second time against the same database does not fail.
        query = (
            "CREATE CONSTRAINT IF NOT EXISTS "
            "FOR (n:Node) REQUIRE n._id IS UNIQUE"
        )
        async with self._driver.session() as session:
            result = await session.run(query)
            await result.consume()

    async def has_node(self, node_id: str) -> bool:
        """Return True if a node whose ``_id`` equals *node_id* exists."""
        # Aggregating guarantees exactly one row, whatever the match count.
        query = (
            "MATCH (n:Node {_id: $node_id}) "
            "RETURN count(n) > 0 AS nodeExists"
        )
        async with self._driver.session() as session:
            result = await session.run(query, node_id=node_id)
            record = await result.single()
            return bool(record["nodeExists"]) if record else False

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if any relationship links the two nodes (any direction)."""
        query = (
            "MATCH (n1:Node {_id: $node1_id})-[r]-(n2:Node {_id: $node2_id}) "
            "RETURN count(r) > 0 AS relationshipExists"
        )
        async with self._driver.session() as session:
            result = await session.run(
                query, node1_id=source_node_id, node2_id=target_node_id
            )
            record = await result.single()
            return bool(record["relationshipExists"]) if record else False

    async def node_degree(self, node_id: str) -> int:
        """Return the number of relationships attached to *node_id*.

        Returns 0 when the node does not exist or has no relationships.
        """
        query = (
            "MATCH (n:Node {_id: $node_id})-[r]-() "
            "RETURN count(r) AS degree"
        )
        async with self._driver.session() as session:
            result = await session.run(query, node_id=node_id)
            record = await result.single()
            return record["degree"] if record else 0

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of the degrees of the two endpoint nodes."""
        # Delegating to node_degree keeps the degree definition in one place
        # and yields 0 (instead of raising) for a missing endpoint.
        src_degree = await self.node_degree(src_id)
        tgt_degree = await self.node_degree(tgt_id)
        return src_degree + tgt_degree

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the properties of node *node_id*, or None if it is absent.

        The returned dict includes the ``_id`` property itself.
        """
        query = "MATCH (n:Node {_id: $node_id}) RETURN n"
        async with self._driver.session() as session:
            result = await session.run(query, node_id=node_id)
            record = await result.single()
            if record is None:
                return None
            return dict(record["n"])

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the properties of an edge between the two nodes, or None.

        Direction is ignored. If several relationships link the two nodes,
        the properties of one of them are returned.
        """
        # Nodes are addressed by their ``_id`` property; the Neo4j-internal
        # id() is an integer and never equals an application-level string id.
        query = (
            "MATCH (a:Node {_id: $start_node_id})-[r]-(b:Node {_id: $end_node_id}) "
            "RETURN r LIMIT 1"
        )
        async with self._driver.session() as session:
            result = await session.run(
                query,
                start_node_id=source_node_id,
                end_node_id=target_node_id,
            )
            record = await result.single()
            if record is None:
                return None
            return dict(record["r"])

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        """Return ``(source_node_id, neighbour_id)`` pairs for every neighbour.

        Direction is ignored, matching ``node_degree``/``has_edge``. An empty
        list is returned when the node is missing or isolated.
        """
        query = (
            "MATCH (n:Node {_id: $start_node_id})-[]-(m) "
            "RETURN DISTINCT m._id AS neighbor_id"
        )
        async with self._driver.session() as session:
            result = await session.run(query, start_node_id=source_node_id)
            # The async result must be consumed with ``async for``.
            return [
                (source_node_id, record["neighbor_id"])
                async for record in result
            ]

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Create node *node_id* if needed and merge *node_data* into it."""
        # Build a fresh dict so the caller's mapping is not mutated.
        props = {**node_data, "_id": node_id}
        query = "MERGE (n:Node {_id: $node_id}) SET n += $props"
        async with self._driver.session() as session:
            result = await session.run(query, node_id=node_id, props=props)
            # Consuming the result ensures the auto-commit query has finished.
            await result.consume()

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Create or update the edge from source to target with *edge_data*.

        Nothing is written when either endpoint node does not exist.
        """
        # NOTE(review): "YOUR_RELATIONSHIP_TYPE" looks like a leftover
        # placeholder; it is kept so already-stored relationships still match.
        query = (
            "MATCH (source:Node {_id: $source_node_id}), "
            "(target:Node {_id: $target_node_id}) "
            "MERGE (source)-[edge:YOUR_RELATIONSHIP_TYPE]->(target) "
            "SET edge += $edge_data"
        )
        async with self._driver.session() as session:
            result = await session.run(
                query,
                source_node_id=source_node_id,
                target_node_id=target_node_id,
                edge_data=edge_data,
            )
            await result.consume()

    async def clustering(self, algorithm: str):
        """Not implemented for this backend."""
        raise NotImplementedError

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """Return the community representation with report and nodes"""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError("Node embedding is not used in nano-graphrag.")

0 comments on commit 32e9ddd

Please sign in to comment.