Skip to content

Commit

Permalink
Merge pull request #831 from neo4j-contrib/feature/resolve_subgraphs
Browse files Browse the repository at this point in the history
Added method to resolve a subgraph from a fetch_relations() call.
  • Loading branch information
mariusconjeaud authored Sep 24, 2024
2 parents 2060309 + 40a60fb commit cad5936
Show file tree
Hide file tree
Showing 5 changed files with 400 additions and 57 deletions.
3 changes: 2 additions & 1 deletion docker-scripts/docker-neo4j.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ docker run \
--env NEO4J_AUTH=neo4j/foobarbaz \
--env NEO4J_ACCEPT_LICENSE_AGREEMENT=yes \
--env NEO4JLABS_PLUGINS='["apoc"]' \
neo4j:$1
--rm \
neo4j:$1
148 changes: 120 additions & 28 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Optional

from neomodel.async_.core import AsyncStructuredNode, adb
from neomodel.async_.relationship import AsyncStructuredRel
from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
from neomodel.properties import AliasProperty, ArrayProperty
Expand Down Expand Up @@ -414,16 +415,18 @@ def __init__(
self.lookup = lookup
self.additional_return = additional_return if additional_return else []
self.is_count = is_count
self.subgraph: dict = {}


class AsyncQueryBuilder:
def __init__(self, node_set):
def __init__(self, node_set, with_subgraph: bool = False):
self.node_set = node_set
self._ast = QueryAST()
self._query_params = {}
self._place_holder_registry = {}
self._ident_count = 0
self._node_counters = defaultdict(int)
self._with_subgraph: bool = with_subgraph

async def build_ast(self):
if hasattr(self.node_set, "relations_to_fetch"):
Expand Down Expand Up @@ -516,6 +519,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
stmt: str = ""
source_class_iterator = source_class
parts = path.split("__")
if self._with_subgraph:
subgraph = self._ast.subgraph
for index, part in enumerate(parts):
relationship = getattr(source_class_iterator, part)
# build source
Expand Down Expand Up @@ -549,6 +554,13 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
lhs_ident = stmt

rel_ident = self.create_ident()
if self._with_subgraph and part not in self._ast.subgraph:
subgraph[part] = {
"target": relationship.definition["node_class"],
"children": {},
"variable_name": rhs_name,
"rel_variable_name": rel_ident,
}
if relation["include_in_return"]:
self._additional_return(rel_ident)
stmt = _rel_helper(
Expand All @@ -559,6 +571,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
relation_type=relationship.definition["relation_type"],
)
source_class_iterator = relationship.definition["node_class"]
if self._with_subgraph:
subgraph = subgraph[part]["children"]

if relation.get("optional"):
self._ast.optional_match.append(stmt)
Expand Down Expand Up @@ -778,7 +792,7 @@ async def _contains(self, node_element_id):
self._query_params[place_holder] = node_element_id
return await self._count() >= 1

async def _execute(self, lazy=False):
async def _execute(self, lazy: bool = False, dict_output: bool = False):
if lazy:
# inject id() into return or return_set
if self._ast.return_clause:
Expand All @@ -791,9 +805,13 @@ async def _execute(self, lazy=False):
for item in self._ast.additional_return
]
query = self.build_query()
results, _ = await adb.cypher_query(
results, prop_names = await adb.cypher_query(
query, self._query_params, resolve_objects=True
)
if dict_output:
for item in results:
yield dict(zip(prop_names, item))
return
# The following is not as elegant as it could be but had to be copied from the
# version prior to cypher_query with the resolve_objects capability.
# It seems that certain calls are only supposed to be focusing to the first
Expand Down Expand Up @@ -1096,39 +1114,42 @@ def order_by(self, *props):

return self

def _add_relations(
self, *relation_names, include_in_return=True, **aliased_relation_names
def _register_relation_to_fetch(
self, relation_def: Any, alias: str = None, include_in_return: bool = True
):
"""Specify a set of relations to return."""
relations = []

def register_relation_to_fetch(relation_def: Any, alias: str = None):
if isinstance(relation_def, Optional):
item = {"path": relation_def.relation, "optional": True}
else:
item = {"path": relation_def}
item["include_in_return"] = include_in_return
if alias:
item["alias"] = alias
relations.append(item)
if isinstance(relation_def, Optional):
item = {"path": relation_def.relation, "optional": True}
else:
item = {"path": relation_def}
item["include_in_return"] = include_in_return
if alias:
item["alias"] = alias
return item

def fetch_relations(self, *relation_names):
"""Specify a set of relations to traverse and return."""
relations = []
for relation_name in relation_names:
register_relation_to_fetch(relation_name)
for alias, relation_def in aliased_relation_names.items():
register_relation_to_fetch(relation_def, alias)

relations.append(self._register_relation_to_fetch(relation_name))
self.relations_to_fetch = relations
return self

def fetch_relations(self, *relation_names, **aliased_relation_names):
"""Specify a set of relations to traverse and return."""
return self._add_relations(*relation_names, **aliased_relation_names)

def traverse_relations(self, *relation_names, **aliased_relation_names):
"""Specify a set of relations to traverse only."""
return self._add_relations(
*relation_names, include_in_return=False, **aliased_relation_names
)
relations = []
for relation_name in relation_names:
relations.append(
self._register_relation_to_fetch(relation_name, include_in_return=False)
)
for alias, relation_def in aliased_relation_names.items():
relations.append(
self._register_relation_to_fetch(
relation_def, alias, include_in_return=False
)
)

self.relations_to_fetch = relations
return self

def annotate(self, *vars, **aliased_vars):
"""Annotate node set results with extra variables."""
Expand All @@ -1146,6 +1167,77 @@ def register_extra_var(vardef, varname: str = None):

return self

def _to_subgraph(self, root_node, other_nodes, subgraph):
    """
    Recursively build root_node's relation graph from subgraph.

    :param root_node: Object that receives a fresh ``_relations`` dict.
        Any ``_relations`` it already carries is replaced.
    :param other_nodes: Mapping of query variable name to the value
        returned for that variable (a single object, a list, or None).
    :param subgraph: Mapping of relation name to a definition dict with
        the keys ``variable_name``, ``rel_variable_name`` and ``children``.
    :return: ``root_node``, with ``_relations`` populated.

    Values matched through ``variable_name`` or ``rel_variable_name`` are
    stored under the relation name; values that are relationship objects
    are stored under ``<relation name>_relationship`` instead.
    """
    root_node._relations = {}
    for name, relation_def in subgraph.items():
        wanted_variables = (
            relation_def["variable_name"],
            relation_def["rel_variable_name"],
        )
        children = relation_def["children"]
        for var_name, node in other_nodes.items():
            if var_name not in wanted_variables or node is None:
                continue
            # The storage key is computed per matched variable rather than
            # by mutating ``name``: otherwise the "_relationship" suffix
            # would leak into the following iterations and the outcome
            # would depend on the column order of ``other_nodes``.
            if isinstance(node, list):
                holds_rels = len(node) > 0 and isinstance(
                    node[0], AsyncStructuredRel
                )
                key = name + "_relationship" if holds_rels else name
                root_node._relations[key] = [
                    self._to_subgraph(item, other_nodes, children)
                    for item in node
                ]
            else:
                is_rel = isinstance(node, AsyncStructuredRel)
                key = name + "_relationship" if is_rel else name
                root_node._relations[key] = self._to_subgraph(
                    node, other_nodes, children
                )

    return root_node

async def resolve_subgraph(self) -> list:
    """
    Convert every result contained in this node set to a subgraph.

    By default, we receive results from neomodel as a list of
    nodes without the hierarchy. This method tries to rebuild this
    hierarchy without overriding anything in the node, that's why
    we use a dedicated property (``_relations``) to store node's relations.

    :return: One entry per result row: the row's root node, as returned
        by ``_to_subgraph``.
    :raises RuntimeError: If no relations were registered on this node set.
    :raises NotImplementedError: If the first registered relation is not
        included in the return clause (i.e. it is traversed only).
    """
    if not self.relations_to_fetch:
        raise RuntimeError(
            "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
        )
    if not self.relations_to_fetch[0]["include_in_return"]:
        raise NotImplementedError(
            "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
        )
    results: list = []
    qbuilder = self.query_cls(self, with_subgraph=True)
    await qbuilder.build_ast()
    all_nodes = qbuilder._execute(dict_output=True)
    other_nodes = {}
    root_node = None
    async for row in all_nodes:
        for name, node in row.items():
            # NOTE(review): the underscore test presumably tells the root
            # variable apart from generated identifiers of the same class;
            # confirm against the query builder's naming scheme.
            if node.__class__ is self.source and "_" not in name:
                root_node = node
            elif isinstance(node, list) and node and isinstance(node[0], list):
                # Nested list: keep only the first inner list. The ``node``
                # truthiness guard keeps an empty list from raising
                # IndexError on ``node[0]``.
                other_nodes[name] = node[0]
            else:
                other_nodes[name] = node
        results.append(
            self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph)
        )
    return results


class AsyncTraversal(AsyncBaseSet):
"""
Expand Down
Loading

0 comments on commit cad5936

Please sign in to comment.