-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
133 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from nano_graphrag._storage import BaseGraphStorage | ||
from neo4j import AsyncGraphDatabase | ||
import neo4j | ||
from typing import Union | ||
from nano_graphrag.graphrag import always_get_an_event_loop | ||
from nano_graphrag.base import SingleCommunitySchema | ||
import numpy as np | ||
|
||
import nest_asyncio | ||
nest_asyncio.apply() | ||
|
||
class NetworkXStorage(BaseGraphStorage): | ||
def __init__(self, uri: str, user: str, password: str): | ||
self._driver: neo4j.AsyncDriver = AsyncGraphDatabase(uri, auth=(user, password)) | ||
loop = always_get_an_event_loop() | ||
loop.run_until_complete(self._secure_table()) | ||
async def _secure_table(self): | ||
async with self._driver.session() as session: | ||
await session.run("CREATE CONSTRAINT ON (n:_id) ASSERT n._id IS UNIQUE;") | ||
async def has_node(self, node_id: str) -> bool: | ||
query = "MATCH (n) WHERE n._id = $node_id RETURN n IS NOT NULL AS nodeExists" | ||
|
||
async with self._driver.session() as session: | ||
result = await session.run(query, node_id=node_id) | ||
record = await result.single() | ||
if record: | ||
return record["nodeExists"] | ||
else: | ||
return False | ||
|
||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||
query = ( | ||
"MATCH (n1)-[r]-(n2) " | ||
"WHERE n1._id = $node1_id AND n2._id = $node2_id " | ||
"RETURN COUNT(r) > 0 AS relationshipExists" | ||
) | ||
|
||
async with self._driver.session() as session: | ||
result = await session.run(query, node1_id=source_node_id, node2_id=target_node_id) | ||
record = await result.single() | ||
if record: | ||
return record["relationshipExists"] | ||
else: | ||
return False | ||
|
||
async def node_degree(self, node_id: str) -> int: | ||
query = ( | ||
"MATCH (n)-[r]-() " | ||
"WHERE n._id = $node_id " | ||
"RETURN count(r) AS degree" | ||
) | ||
|
||
async with self._driver.session() as session: | ||
result = await session.run(query, node_id=node_id) | ||
record = await result.single() | ||
if record: | ||
return record["degree"] | ||
else: | ||
return 0 | ||
|
||
async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||
async with self._driver.session() as session: | ||
src_degree = (await (await session.run("MATCH (n) WHERE n._id = $src_id RETURN size((n)--()) AS degree", src_id=src_id)).single())["degree"] | ||
tgt_degree = (await (await session.run("MATCH (n) WHERE n._id = $tgt_id RETURN size((n)--()) AS degree", tgt_id=tgt_id)).single())["degree"] | ||
|
||
return src_degree + tgt_degree | ||
|
||
async def get_node(self, node_id: str) -> Union[dict, None]: | ||
async with self._driver.session() as session: | ||
result = await session.run("MATCH (n) WHERE n._id = $node_id RETURN n", node_id=node_id) | ||
record = await result.single() | ||
if record: | ||
node = record["n"] | ||
properties = dict(node) | ||
return properties | ||
else: | ||
return None | ||
|
||
async def get_edge( | ||
self, source_node_id: str, target_node_id: str | ||
) -> Union[dict, None]: | ||
async with self._driver.session() as session: | ||
result = await session.run("MATCH (start)-[r]-(end) WHERE id(start) = $start_node_id AND id(end) = $end_node_id RETURN r", start_node_id=source_node_id, end_node_id=target_node_id) | ||
record = await result.single() | ||
if not record: | ||
return None | ||
relationship = record["r"] | ||
properties = dict(relationship) | ||
|
||
return properties | ||
|
||
async def get_node_edges( | ||
self, source_node_id: str | ||
) -> Union[list[tuple[str, str]], None]: | ||
async with self._driver.session() as session: | ||
result = await session.run("MATCH (startNode)-[]->(endNode) " | ||
"WHERE id(startNode) = $start_node_id " | ||
"RETURN endNode", start_node_id=source_node_id) | ||
return [(source_node_id, record["endNode"]) for record in result] | ||
|
||
async def upsert_node(self, node_id: str, node_data: dict[str, str]): | ||
node_data['_id'] = node_id | ||
query = "MERGE (n:Node {id: $_id}) SET n += $props RETURN id(n)" | ||
async with self._driver.session() as session: | ||
async with session.begin_transaction() as tx: | ||
await tx.run(query, _id=node_data['_id'], props=node_data) | ||
|
||
async def upsert_edge( | ||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] | ||
): | ||
async with self._driver.session() as session: | ||
async with session.begin_transaction() as tx: | ||
query = ( | ||
"MATCH (source), (target) " | ||
"WHERE source.id = $source_node_id AND target.id = $target_node_id " | ||
"MERGE (source)-[edge:YOUR_RELATIONSHIP_TYPE]->(target) " | ||
"SET edge += $edge_data" | ||
) | ||
await tx.run(query, source_node_id=source_node_id, target_node_id=target_node_id, edge_data=edge_data) | ||
|
||
async def clustering(self, algorithm: str): | ||
raise NotImplementedError | ||
|
||
async def community_schema(self) -> dict[str, SingleCommunitySchema]: | ||
"""Return the community representation with report and nodes""" | ||
raise NotImplementedError | ||
|
||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: | ||
raise NotImplementedError("Node embedding is not used in nano-graphrag.") |