Skip to content

Commit

Permalink
chore: make graph stat computation asynchronous
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Feb 6, 2024
1 parent 54fb112 commit a18ad48
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 89 deletions.
21 changes: 5 additions & 16 deletions neo4j-app/neo4j_app/app/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from neo4j_app.app import ServiceConfig
from neo4j_app.app.dependencies import lifespan_neo4j_driver
from neo4j_app.app.doc import DOC_GRAPH_DUMP, DOC_GRAPH_DUMP_DESC, GRAPH_TAG
from neo4j_app.core.neo4j.graphs import count_documents_and_named_entities, dump_graph
from neo4j_app.core.neo4j.graphs import dump_graph, project_statistics
from neo4j_app.core.objects import DumpRequest, GraphCounts
from neo4j_app.core.utils.logging import log_elapsed_time_cm

Expand Down Expand Up @@ -50,26 +50,15 @@ async def _graph_dump(
return res

@router.get("/counts", response_model=GraphCounts)
async def _count_documents_and_named_entities(project: str) -> GraphCounts:
    """Return the document and named-entity counts of the project graph.

    Counts are read from the precomputed ``_ProjectStatistics`` node rather
    than recomputed, so this endpoint is cheap even on large graphs.

    :param project: name of the project whose graph is inspected
    :return: the graph's document and per-label named-entity counts
    """
    # Log how long the (read-only) stats lookup takes
    with log_elapsed_time_cm(
        logger,
        logging.INFO,
        "Counted documents and named entities in {elapsed_time} !",
    ):
        stats = await project_statistics(
            project=project, neo4j_driver=lifespan_neo4j_driver()
        )
    return stats.counts

return router
5 changes: 5 additions & 0 deletions neo4j-app/neo4j_app/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@
NEO4J_CSV_START_ID = ":START_ID"
NEO4J_CSV_TYPE = ":TYPE"

STATS_NODE = "_ProjectStatistics"
STATS_N_DOCS = "nDocuments"
STATS_N_ENTS = "nEntities"
STATS_ID = "id"

TASK_NODE = "_Task"
TASK_COMPLETED_AT = "completedAt"
TASK_CREATED_AT = "createdAt"
Expand Down
1 change: 0 additions & 1 deletion neo4j-app/neo4j_app/core/elasticsearch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
SIZE,
SLICE,
SORT,
SOURCE,
match_all,
)
from neo4j_app.core.neo4j import Neo4jImportWorker, write_neo4j_csv
Expand Down
22 changes: 14 additions & 8 deletions neo4j-app/neo4j_app/core/neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,50 @@
migration_v_0_5_0_tx,
migration_v_0_6_0,
migration_v_0_7_0_tx,
migration_v_0_8_0,
)

# Ordered list of schema migrations. Each migration pairs a semantic version
# with either a transaction function (suffix `_tx`) or a session-level
# function (no suffix) — see migrate.py, which dispatches on the first
# parameter name ("tx" vs session).
V_0_1_0 = Migration(
    version="0.1.0",
    label="create migration and project and constraints",
    migration_fn=migration_v_0_1_0_tx,
)
V_0_2_0 = Migration(
    version="0.2.0",
    label="create doc and named entities index and constraints",
    migration_fn=migration_v_0_2_0_tx,
)
V_0_3_0 = Migration(
    version="0.3.0",
    label="create tasks indexes and constraints",
    migration_fn=migration_v_0_3_0_tx,
)
V_0_4_0 = Migration(
    version="0.4.0",
    label="create document path and content type indexes",
    migration_fn=migration_v_0_4_0_tx,
)
V_0_5_0 = Migration(
    version="0.5.0",
    label="create email user and domain indexes",
    migration_fn=migration_v_0_5_0_tx,
)
V_0_6_0 = Migration(
    version="0.6.0",
    label="add mention counts to named entity document relationships",
    migration_fn=migration_v_0_6_0,
)
V_0_7_0 = Migration(
    version="0.7.0",
    label="create document modified and created at indexes",
    migration_fn=migration_v_0_7_0_tx,
)
V_0_8_0 = Migration(
    version="0.8.0",
    label="compute project stats and create stats unique constraint",
    migration_fn=migration_v_0_8_0,
)
MIGRATIONS = [V_0_1_0, V_0_2_0, V_0_3_0, V_0_4_0, V_0_5_0, V_0_6_0, V_0_7_0, V_0_8_0]


def get_neo4j_csv_reader(
Expand Down
78 changes: 57 additions & 21 deletions neo4j-app/neo4j_app/core/neo4j/graphs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from copy import deepcopy
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, Dict, Optional

import neo4j

Expand All @@ -14,7 +14,7 @@
NE_NODE,
)
from neo4j_app.core.neo4j.projects import project_db
from neo4j_app.core.objects import DumpFormat, GraphCounts
from neo4j_app.core.objects import DumpFormat, ProjectStatistics

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,30 +135,66 @@ async def _dump_subgraph_to_cypher(
yield rec["cypherStatements"]


async def project_statistics(
    neo4j_driver: neo4j.AsyncDriver, project: str
) -> ProjectStatistics:
    """Read the precomputed statistics of a project graph.

    :param neo4j_driver: driver used to open a session on the project DB
    :param project: name of the project whose statistics are read
    :return: the stored statistics, or default (zero) statistics when no
        stats node exists yet
    """
    neo4j_db = await project_db(neo4j_driver, project)
    async with neo4j_driver.session(database=neo4j_db) as sess:
        # Read-only lookup of the singleton _ProjectStatistics node
        stats = await sess.execute_read(ProjectStatistics.from_neo4j)
        return stats


async def _count_documents_and_named_entities_tx(
tx: neo4j.AsyncTransaction, parallel: bool
) -> GraphCounts:
runtime = "CYPHER runtime=parallel" if parallel else ""
doc_query = f"""{runtime}
MATCH (doc:{DOC_NODE}) RETURN count(*) as nDocs
"""
async def refresh_project_statistics(
    neo4j_driver: neo4j.AsyncDriver, project: str
) -> ProjectStatistics:
    """Recompute the project graph statistics and persist them.

    :param neo4j_driver: driver used to open a session on the project DB
    :param project: name of the project whose statistics are refreshed
    :return: the freshly computed statistics
    """
    database = await project_db(neo4j_driver, project)
    async with neo4j_driver.session(database=database) as neo4j_session:
        # Write transaction: counts are recomputed then MERGEd into the
        # singleton stats node
        return await neo4j_session.execute_write(refresh_project_statistics_tx)


async def _count_documents_tx(
    tx: neo4j.AsyncTransaction, document_counts_key: str = "nDocs"
) -> int:
    """Count the document nodes of the current DB.

    :param tx: transaction in which the count query runs
    :param document_counts_key: name of the count column in the query result
    :return: the number of document nodes
    """
    doc_query = f"""
    MATCH (doc:{DOC_NODE}) RETURN count(*) as nDocs
    """
    doc_res = await tx.run(doc_query)
    # count(*) always yields exactly one record, single() is safe here
    doc_res = await doc_res.single()
    n_docs = doc_res[document_counts_key]
    return n_docs


async def _count_entities_tx(
    tx: neo4j.AsyncTransaction,
    entity_labels_key: str = "neLabels",
    entity_counts_key: str = "nMentions",
) -> Dict[str, int]:
    """Count named-entity mentions grouped by entity category label.

    Mention counts are summed from the appears-in-document relationships.

    :param tx: transaction in which the count query runs
    :param entity_labels_key: name of the labels column in the query result
    :param entity_counts_key: name of the count column in the query result
    :return: mapping of entity category label -> total mention count
    :raises ValueError: when an entity node does not carry exactly one
        category label besides the generic named-entity label
    """
    entity_query = f"""MATCH (ne:{NE_NODE})
WITH ne, labels(ne) as neLabels
MATCH (ne)-[rel:{NE_APPEARS_IN_DOC}]->()
RETURN neLabels, sum(rel.{NE_MENTION_COUNT}) as nMentions"""
    entity_res = await tx.run(entity_query)
    n_ents = dict()
    async for rec in entity_res:
        # Entities are expected to be labeled (NE_NODE, <category>); the
        # category is whatever remains after dropping the generic label
        category_labels = [lbl for lbl in rec[entity_labels_key] if lbl != NE_NODE]
        if len(category_labels) != 1:
            msg = (
                "Expected named entity to have exactly 2 labels."
                " Refactor this function."
            )
            raise ValueError(msg)
        n_ents[category_labels[0]] = rec[entity_counts_key]
    return n_ents


async def refresh_project_statistics_tx(
    tx: neo4j.AsyncTransaction,
) -> ProjectStatistics:
    """Recompute document and entity counts and persist them as stats.

    :param tx: write transaction in which counts are computed and stored
    :return: the freshly computed statistics
    """
    # We could update the stats directly in DB, however since _count_entities_tx
    # needs to perform advanced error handling, we quickly get back to Python
    # before re-writing the whole stats
    document_count = await _count_documents_tx(tx)
    entity_counts = await _count_entities_tx(tx)
    return await ProjectStatistics.to_neo4j_tx(tx, document_count, entity_counts)
4 changes: 2 additions & 2 deletions neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ async def _migrate_with_lock(
# the DB...

# Lock the DB first, raising in case a migration already exists
logger.debug("Trying to run migration %s...", migration.label)
logger.debug("Trying to run migration to %s...", migration.label)
await registry_session.execute_write(
create_migration_tx,
project=project,
migration_version=str(migration.version),
migration_label=migration.label,
)
# Then run to migration
logger.debug("Acquired write lock for %s !", migration.label)
logger.debug("Acquired write lock to %s !", migration.label)
sig = signature(migration.migration_fn)
first_param = list(sig.parameters)[0]
if first_param == "tx":
Expand Down
32 changes: 32 additions & 0 deletions neo4j-app/neo4j_app/core/neo4j/migrations/migrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import neo4j

from neo4j_app.constants import (
Expand All @@ -20,6 +22,8 @@
NE_NODE,
PROJECT_NAME,
PROJECT_NODE,
STATS_ID,
STATS_NODE,
TASK_CREATED_AT,
TASK_ERROR_ID,
TASK_ERROR_NODE,
Expand All @@ -31,6 +35,9 @@
TASK_NODE,
TASK_TYPE,
)
from neo4j_app.core.neo4j.graphs import refresh_project_statistics_tx

logger = logging.getLogger(__name__)


async def migration_v_0_1_0_tx(tx: neo4j.AsyncTransaction):
Expand Down Expand Up @@ -68,6 +75,11 @@ async def migration_v_0_7_0_tx(tx: neo4j.AsyncTransaction):
await _create_document_created_and_modified_at_indexes(tx)


async def migration_v_0_8_0(sess: neo4j.AsyncSession):
    """Create the stats uniqueness constraint, then compute the stats node
    when it does not exist yet (each step in its own write transaction)."""
    for write_tx_fn in (
        _create_project_stats_unique_constraint_tx,
        refresh_project_statistics_if_needed_tx,
    ):
        await sess.execute_write(write_tx_fn)


async def _create_document_and_ne_id_unique_constraint_tx(tx: neo4j.AsyncTransaction):
doc_query = f"""CREATE CONSTRAINT constraint_document_unique_id
IF NOT EXISTS
Expand Down Expand Up @@ -189,3 +201,23 @@ async def _create_document_created_and_modified_at_indexes(tx: neo4j.AsyncTransa
FOR (doc:{DOC_NODE})
ON (doc.{DOC_MODIFIED_AT})"""
await tx.run(modified_at_index)


async def _create_project_stats_unique_constraint_tx(tx: neo4j.AsyncTransaction):
    """Enforce uniqueness of the stats node ID (stats node is a singleton)."""
    constraint_query = f"""CREATE CONSTRAINT constraint_stats_unique_id
IF NOT EXISTS
FOR (s:{STATS_NODE})
REQUIRE (s.{STATS_ID}) IS UNIQUE
"""
    await tx.run(constraint_query)


async def refresh_project_statistics_if_needed_tx(tx: neo4j.AsyncTransaction):
    """Compute and store the project statistics unless they already exist.

    :param tx: write transaction used both for the existence check and, when
        needed, the stats computation
    """
    # Existence check only: LIMIT 1 avoids scanning every stats node and
    # keeps res.single() from raising if several nodes somehow exist
    exists_query = f"MATCH (s:{STATS_NODE}) RETURN s LIMIT 1"
    res = await tx.run(exists_query)
    existing = await res.single()
    if existing is None:
        logger.info("missing graph statistics, computing them...")
        await refresh_project_statistics_tx(tx)
    else:
        logger.info("stats are already computed, skipping!")
69 changes: 45 additions & 24 deletions neo4j-app/neo4j_app/core/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import json
from datetime import datetime
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union
from typing import Any, ClassVar, Dict, List, Optional, Union

import neo4j
from pydantic import Field

from neo4j_app.constants import NE_NODE
from neo4j_app.constants import STATS_ID, STATS_NODE, STATS_N_DOCS, STATS_N_ENTS
from neo4j_app.core.utils.pydantic import LowerCamelCaseModel, NoEnumModel
from neo4j_app.icij_worker.task import Task, TaskStatus

Expand Down Expand Up @@ -40,29 +40,50 @@ class GraphCounts(LowerCamelCaseModel):
documents: int = 0
named_entities: Dict[str, int] = Field(default_factory=dict)


class ProjectStatistics(LowerCamelCaseModel):
    """Statistics of a project graph, persisted as a singleton neo4j node."""

    # Fixed ID of the singleton stats node. A plain ClassVar string: pydantic
    # excludes ClassVars from model fields, so wrapping it in Field() would
    # leave a FieldInfo object on the class and force `.default` access.
    singleton_stat_id: ClassVar[str] = "project-stats-singleton-id"
    # Document and per-label entity counts; defaults to all-zero counts
    counts: GraphCounts = Field(default_factory=GraphCounts)

    @classmethod
    async def from_neo4j(cls, tx: neo4j.AsyncTransaction) -> ProjectStatistics:
        """Read the stats node of the current DB.

        :param tx: read transaction
        :return: the stored statistics, or default statistics when no stats
            node exists
        :raises ValueError: when several stats nodes are found
        """
        query = f"MATCH (stats:{STATS_NODE}) RETURN *"
        stats_res = await tx.run(query)
        stats = [s async for s in stats_res]
        if not stats:
            return cls()
        if len(stats) > 1:
            raise ValueError("Inconsistent state, found several project statistics")
        stats_node = stats[0]["stats"]
        # Entity counts are stored flattened as [label0, count0, label1,
        # count1, ...] since neo4j properties can't hold maps; counts are
        # stored as strings
        ent_counts_as_list = stats_node[STATS_N_ENTS]
        ent_counts = dict()
        for ent_ix in range(0, len(ent_counts_as_list), 2):
            ent_counts[ent_counts_as_list[ent_ix]] = int(
                ent_counts_as_list[ent_ix + 1]
            )
        counts = GraphCounts(
            documents=stats_node[STATS_N_DOCS], named_entities=ent_counts
        )
        return cls(counts=counts)

    @classmethod
    async def to_neo4j_tx(
        cls, tx: neo4j.AsyncTransaction, doc_count: int, ent_counts: Dict[str, int]
    ) -> ProjectStatistics:
        """Persist the provided counts on the singleton stats node.

        :param tx: write transaction
        :param doc_count: number of document nodes
        :param ent_counts: mapping of entity label -> mention count
        :return: the persisted statistics
        """
        query = f"""MERGE (s:{STATS_NODE} {{ {STATS_ID}: $singletonId }})
SET s.{STATS_N_DOCS} = $docCount, s.{STATS_N_ENTS} = $entCounts"""
        # Flatten the mapping into [label, str(count), ...] pairs (see
        # from_neo4j for the inverse)
        ent_counts_as_list = [
            entry for k, v in ent_counts.items() for entry in (k, str(v))
        ]
        await tx.run(
            query,
            singletonId=cls.singleton_stat_id,
            docCount=doc_count,
            entCounts=ent_counts_as_list,
        )
        counts = GraphCounts(documents=doc_count, named_entities=ent_counts)
        return cls(counts=counts)


class Neo4jCSVRequest(LowerCamelCaseModel):
Expand Down
Loading

0 comments on commit a18ad48

Please sign in to comment.