Skip to content

Commit

Permalink
Improvements and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tonioo committed Oct 22, 2024
1 parent 94afca2 commit 29f8bfa
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 100 deletions.
176 changes: 130 additions & 46 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import re
import string
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand Down Expand Up @@ -571,7 +570,6 @@ def build_traversal_from_path(
if relation.get("relation_filtering"):
rhs_name = rel_ident

Check warning on line 571 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L571

Added line #L571 was not covered by tests
else:
rel_reference = f'{relationship.definition["node_class"]}_{part}'
if index + 1 == len(parts) and "alias" in relation:
# If an alias is defined, use it to store the last hop in the path
rhs_name = relation["alias"]
Expand Down Expand Up @@ -794,20 +792,20 @@ def lookup_query_variable(
if traversals[0] not in subgraph:
return None
subgraph = subgraph[traversals[0]]
if len(traversals) == 1:
variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}"
return variable_to_return, subgraph["target"]
variable_to_return = ""

Check warning on line 798 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L798

Added line #L798 was not covered by tests
last_property = traversals[-1]
for part in traversals:
if part in subgraph["children"]:
subgraph = subgraph["children"][part]
elif part == last_property:
for part in traversals[1:]:
child = subgraph["children"].get(part)
if not child:
return None
subgraph = child
if part == last_property:

Check warning on line 805 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L800-L805

Added lines #L800 - L805 were not covered by tests
# if last part of prop is the last traversal
# we are safe to lookup the variable from the query
if return_relation:
variable_to_return = f"{subgraph['rel_variable_name']}"
else:
variable_to_return = f"{subgraph['variable_name']}"
else:
return None
variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}"
return variable_to_return, subgraph["target"]

Check warning on line 809 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L808-L809

Added lines #L808 - L809 were not covered by tests

def build_query(self) -> str:
Expand Down Expand Up @@ -844,6 +842,9 @@ def build_query(self) -> str:
for transform in self.node_set._intermediate_transforms:
query += " WITH "
injected_vars: list = []
# Reset return list since we'll probably invalidate most variables
self._ast.return_clause = ""
self._ast.additional_return = []
for name, source in transform["vars"].items():
if type(source) is str:
injected_vars.append(f"{source} AS {name}")
Expand All @@ -856,6 +857,13 @@ def build_query(self) -> str:
f"Unable to resolve variable name for relation {source.relation}."
)
injected_vars.append(f"{result[0]} AS {name}")
elif isinstance(source, NodeNameResolver):
result = self.lookup_query_variable(source.node)
if not result:
raise ValueError(

Check warning on line 863 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L860-L863

Added lines #L860 - L863 were not covered by tests
f"Unable to resolve variable name for node {source.node}."
)
injected_vars.append(f"{result[0]} AS {name}")

Check warning on line 866 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L866

Added line #L866 was not covered by tests
query += ",".join(injected_vars)
if not transform["ordering"]:
continue
Expand All @@ -876,6 +884,17 @@ def build_query(self) -> str:
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
for varname in return_set:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
# This is probably a temporary solution until we find something better...
self._ast.subgraph[varname] = {
"target": None, # We don't need target class in this use case
"children": {},
"variable_name": varname,
"rel_variable_name": varname,
}
returned_items += return_set

query += " RETURN "
Expand All @@ -884,12 +903,18 @@ def build_query(self) -> str:
if self._ast.additional_return:
returned_items += self._ast.additional_return
if hasattr(self.node_set, "_extra_results"):
for varname, vardef in self.node_set._extra_results.items():
for props in self.node_set._extra_results:
leftpart = props["vardef"].render(self)
varname = (
props["alias"]
if props.get("alias")
else props["vardef"].get_internal_name()
)
if varname in returned_items:
# We're about to override an existing variable, delete it first to
# avoid duplicate error
returned_items.remove(varname)
returned_items.append(f"{str(vardef)} AS {varname}")
returned_items.append(f"{leftpart} AS {varname}")

query += ", ".join(returned_items)

Expand Down Expand Up @@ -1062,10 +1087,62 @@ class Optional:


@dataclass
class AggregatingFunction:
class RelationNameResolver:
"""Helper to refer to a relation variable name.
Since variable names are generated automatically within MATCH statements (for
anything injected using fetch_relations or traverse_relations), we need a way to
retrieve them.
"""

relation: str


@dataclass
class NodeNameResolver:
"""Helper to refer to a node variable name.
Since variable names are generated automatically within MATCH statements (for
anything injected using fetch_relations or traverse_relations), we need a way to
retrieve them.
"""

node: str


@dataclass
class BaseFunction:
input_name: Union[str, "BaseFunction", NodeNameResolver, RelationNameResolver]

def __post_init__(self) -> None:
self._internal_name: str = ""

def get_internal_name(self) -> str:
return self._internal_name

def resolve_internal_name(self, qbuilder: AsyncQueryBuilder) -> str:
if isinstance(self.input_name, NodeNameResolver):
result = qbuilder.lookup_query_variable(self.input_name.node)

Check warning on line 1127 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1127

Added line #L1127 was not covered by tests
elif isinstance(self.input_name, RelationNameResolver):
result = qbuilder.lookup_query_variable(self.input_name.relation, True)

Check warning on line 1129 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1129

Added line #L1129 was not covered by tests
else:
result = (str(self.input_name), None)
if result is None:
raise ValueError(f"Unknown variable {self.input_name} used in Collect()")

Check warning on line 1133 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1133

Added line #L1133 was not covered by tests
self._internal_name = result[0]
return self._internal_name

def render(self, qbuilder: AsyncQueryBuilder) -> str:
raise NotImplementedError

Check warning on line 1138 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1138

Added line #L1138 was not covered by tests


@dataclass
class AggregatingFunction(BaseFunction):
"""Base aggregating function class."""

input_name: str
pass


@dataclass
Expand All @@ -1074,38 +1151,33 @@ class Collect(AggregatingFunction):

distinct: bool = False

def __str__(self):
def render(self, qbuilder: AsyncQueryBuilder) -> str:
varname = self.resolve_internal_name(qbuilder)
if self.distinct:
return f"collect(DISTINCT {self.input_name})"
return f"collect({self.input_name})"
return f"collect(DISTINCT {varname})"
return f"collect({varname})"


@dataclass
class ScalarFunction:
class ScalarFunction(BaseFunction):
"""Base scalar function class."""

input_name: Union[str, AggregatingFunction]
pass


@dataclass
class Last(ScalarFunction):
"""last() function."""

def __str__(self) -> str:
return f"last({str(self.input_name)})"


@dataclass
class RelationNameResolver:
"""Helper to refer to a relation variable name.
Since variable names are generated automatically within MATCH statements (for
anything injected using fetch_relations or traverse_relations), we need a way to
retrieve them.
"""

relation: str
def render(self, qbuilder: AsyncQueryBuilder) -> str:
if isinstance(self.input_name, str):
content = str(self.input_name)

Check warning on line 1174 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1174

Added line #L1174 was not covered by tests
elif isinstance(self.input_name, BaseFunction):
content = self.input_name.render(qbuilder)
self._internal_name = self.input_name.get_internal_name()
else:
content = self.resolve_internal_name(qbuilder)

Check warning on line 1179 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1179

Added line #L1179 was not covered by tests
return f"last({content})"


@dataclass
Expand Down Expand Up @@ -1156,7 +1228,7 @@ def __init__(self, source) -> None:
self.dont_match: Dict = {}

self.relations_to_fetch: List = []
self._extra_results: dict = {}
self._extra_results: List = []
self._subqueries: list[Tuple[str, list[str]]] = []
self._intermediate_transforms: list = []

Expand Down Expand Up @@ -1357,7 +1429,9 @@ def annotate(self, *vars, **aliased_vars):

def register_extra_var(vardef, varname: Union[str, None] = None):
if isinstance(vardef, (AggregatingFunction, ScalarFunction)):
self._extra_results[varname if varname else vardef.input_name] = vardef
self._extra_results.append(
{"vardef": vardef, "alias": varname if varname else ""}
)
else:
raise NotImplementedError

Expand Down Expand Up @@ -1411,17 +1485,20 @@ async def resolve_subgraph(self) -> list:
we use a dedicated property to store node's relations.
"""
if not self.relations_to_fetch:
raise RuntimeError(
"Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
)
if not self.relations_to_fetch[0]["include_in_return"]:
if (
self.relations_to_fetch
and not self.relations_to_fetch[0]["include_in_return"]
):
raise NotImplementedError(
"You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
)
results: list = []
qbuilder = self.query_cls(self)
await qbuilder.build_ast()
if not qbuilder._ast.subgraph:
raise RuntimeError(
"Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
)
all_nodes = qbuilder._execute(dict_output=True)
other_nodes = {}
root_node = None
Expand Down Expand Up @@ -1454,7 +1531,8 @@ async def subquery(
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var not in nodeset._extra_results
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
Expand All @@ -1463,10 +1541,16 @@ async def subquery(
def intermediate_transform(
self, vars: Dict[str, Any], ordering: TOptional[list] = None
) -> "AsyncNodeSet":
if not vars:
raise ValueError(

Check warning on line 1545 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L1545

Added line #L1545 was not covered by tests
"You must provide one variable at least when calling intermediate_transform()"
)
for name, source in vars.items():
if type(source) is not str and not isinstance(source, RelationNameResolver):
if type(source) is not str and not isinstance(
source, (NodeNameResolver, RelationNameResolver)
):
raise ValueError(
f"Wrong source type specified for variable '{name}', should be a string or an instance of RelationNameResolver"
f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._intermediate_transforms.append({"vars": vars, "ordering": ordering})
return self
Expand Down
Loading

0 comments on commit 29f8bfa

Please sign in to comment.