Skip to content

Commit

Permalink
change name & fix order
Browse files Browse the repository at this point in the history
  • Loading branch information
Dorbmon committed Sep 17, 2024
1 parent 93366a2 commit 0b4013d
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 10 deletions.
9 changes: 5 additions & 4 deletions nano_graphrag/storage/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import nest_asyncio
nest_asyncio.apply()

class AsyncpgVectorStorage(BaseVectorStorage):
class AsyncPGVectorStorage(BaseVectorStorage):
table_name_generator: callable = None
conn_fetcher: callable = None
cosine_better_than_threshold: float = 0.2
Expand All @@ -34,15 +34,16 @@ def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_ge
loop = always_get_an_event_loop()
loop.run_until_complete(self._secure_table())
@asynccontextmanager
async def __get_conn(self):
async def __get_conn(self, vector_register=True):
try:
conn: asyncpg.Connection = await asyncpg.connect(self.dsn)
await register_vector(conn)
if vector_register:
await register_vector(conn)
yield conn
finally:
await conn.close()
async def _secure_table(self):
async with self.conn_fetcher() as conn:
async with self.conn_fetcher(register_vector=False) as conn:
conn: asyncpg.Connection
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
result = await conn.fetch(
Expand Down
129 changes: 129 additions & 0 deletions nano_graphrag/storage/neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from nano_graphrag._storage import BaseGraphStorage
from neo4j import AsyncGraphDatabase
import neo4j
from typing import Union
from nano_graphrag.graphrag import always_get_an_event_loop
from nano_graphrag.base import SingleCommunitySchema
import numpy as np

import nest_asyncio
nest_asyncio.apply()

class NetworkXStorage(BaseGraphStorage):
def __init__(self, uri: str, user: str, password: str):
self._driver: neo4j.AsyncDriver = AsyncGraphDatabase(uri, auth=(user, password))
loop = always_get_an_event_loop()
loop.run_until_complete(self._secure_table())
async def _secure_table(self):
async with self._driver.session() as session:
await session.run("CREATE CONSTRAINT ON (n:_id) ASSERT n._id IS UNIQUE;")
async def has_node(self, node_id: str) -> bool:
query = "MATCH (n) WHERE n._id = $node_id RETURN n IS NOT NULL AS nodeExists"

async with self._driver.session() as session:
result = await session.run(query, node_id=node_id)
record = await result.single()
if record:
return record["nodeExists"]
else:
return False

async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
query = (
"MATCH (n1)-[r]-(n2) "
"WHERE n1._id = $node1_id AND n2._id = $node2_id "
"RETURN COUNT(r) > 0 AS relationshipExists"
)

async with self._driver.session() as session:
result = await session.run(query, node1_id=source_node_id, node2_id=target_node_id)
record = await result.single()
if record:
return record["relationshipExists"]
else:
return False

async def node_degree(self, node_id: str) -> int:
query = (
"MATCH (n)-[r]-() "
"WHERE n._id = $node_id "
"RETURN count(r) AS degree"
)

async with self._driver.session() as session:
result = await session.run(query, node_id=node_id)
record = await result.single()
if record:
return record["degree"]
else:
return 0

async def edge_degree(self, src_id: str, tgt_id: str) -> int:
async with self._driver.session() as session:
src_degree = (await (await session.run("MATCH (n) WHERE n._id = $src_id RETURN size((n)--()) AS degree", src_id=src_id)).single())["degree"]
tgt_degree = (await (await session.run("MATCH (n) WHERE n._id = $tgt_id RETURN size((n)--()) AS degree", tgt_id=tgt_id)).single())["degree"]

return src_degree + tgt_degree

async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
result = await session.run("MATCH (n) WHERE n._id = $node_id RETURN n", node_id=node_id)
record = await result.single()
if record:
node = record["n"]
properties = dict(node)
return properties
else:
return None

async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
async with self._driver.session() as session:
result = await session.run("MATCH (start)-[r]-(end) WHERE id(start) = $start_node_id AND id(end) = $end_node_id RETURN r", start_node_id=source_node_id, end_node_id=target_node_id)
record = await result.single()
if not record:
return None
relationship = record["r"]
properties = dict(relationship)

return properties

async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
async with self._driver.session() as session:
result = await session.run("MATCH (startNode)-[]->(endNode) "
"WHERE id(startNode) = $start_node_id "
"RETURN endNode", start_node_id=source_node_id)
return [(source_node_id, record["endNode"]) for record in result]

async def upsert_node(self, node_id: str, node_data: dict[str, str]):
node_data['_id'] = node_id
query = "MERGE (n:Node {id: $_id}) SET n += $props RETURN id(n)"
async with self._driver.session() as session:
async with session.begin_transaction() as tx:
await tx.run(query, _id=node_data['_id'], props=node_data)

async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
async with self._driver.session() as session:
async with session.begin_transaction() as tx:
query = (
"MATCH (source), (target) "
"WHERE source.id = $source_node_id AND target.id = $target_node_id "
"MERGE (source)-[edge:YOUR_RELATIONSHIP_TYPE]->(target) "
"SET edge += $edge_data"
)
await tx.run(query, source_node_id=source_node_id, target_node_id=target_node_id, edge_data=edge_data)

async def clustering(self, algorithm: str):
raise NotImplementedError

async def community_schema(self) -> dict[str, SingleCommunitySchema]:
"""Return the community representation with report and nodes"""
raise NotImplementedError

async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
raise NotImplementedError("Node embedding is not used in nano-graphrag.")
12 changes: 6 additions & 6 deletions tests/test_asyncpg_vector_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from nano_graphrag import GraphRAG
from nano_graphrag._utils import wrap_embedding_func_with_attrs

from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage
from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage
import asyncpg
from nano_graphrag.graphrag import always_get_an_event_loop
import os
Expand Down Expand Up @@ -36,7 +36,7 @@ async def mock_embedding(texts: list[str]) -> np.ndarray:
@pytest.fixture
def asyncpg_storage(setup_teardown):
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
return AsyncpgVectorStorage(
return AsyncPGVectorStorage(
namespace="test",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand Down Expand Up @@ -67,7 +67,7 @@ async def test_upsert_and_query(asyncpg_storage):
@pytest.mark.asyncio
async def test_persistence(setup_teardown):
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
initial_storage = AsyncpgVectorStorage(
initial_storage = AsyncPGVectorStorage(
namespace="test",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand All @@ -82,7 +82,7 @@ async def test_persistence(setup_teardown):
await initial_storage.upsert(test_data)
await initial_storage.index_done_callback()

new_storage = AsyncpgVectorStorage(
new_storage = AsyncPGVectorStorage(
namespace="test",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand All @@ -100,7 +100,7 @@ async def test_persistence(setup_teardown):
@pytest.mark.asyncio
async def test_persistence_large_dataset(setup_teardown):
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
initial_storage = AsyncpgVectorStorage(
initial_storage = AsyncPGVectorStorage(
namespace="test_large",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand All @@ -115,7 +115,7 @@ async def test_persistence_large_dataset(setup_teardown):
await initial_storage.upsert(large_data)
await initial_storage.index_done_callback()

new_storage = AsyncpgVectorStorage(
new_storage = AsyncPGVectorStorage(
namespace="test_large",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand Down

0 comments on commit 0b4013d

Please sign in to comment.