From 1dc380b8f1ca4eea825bf27a65da7fb3f298918e Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Fri, 20 Sep 2024 16:44:13 +0200 Subject: [PATCH 01/42] Added support for annotations and calling aggregating functions. Also introduced a new traverse_relations() method which only return the first hop of the path in the return clause. --- neomodel/async_/match.py | 112 ++++++++++++++++++++++++++++------ neomodel/sync_/match.py | 112 ++++++++++++++++++++++++++++------ test/async_/test_match_api.py | 42 +++++++++++++ test/sync_/test_match_api.py | 43 ++++++++++++- 4 files changed, 268 insertions(+), 41 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 7f0435fe..c203dcc2 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.exceptions import MultipleNodesReturned @@ -507,7 +507,7 @@ async def build_traversal(self, traversal): return traversal_ident - def _additional_return(self, name): + def _additional_return(self, name: str): if name not in self._ast.additional_return and name != self._ast.return_clause: self._ast.additional_return.append(name) @@ -515,7 +515,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - for index, part in enumerate(path.split("__")): + parts = path.split("__") + for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source if "node_class" not in relationship.definition: @@ -523,11 +524,16 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: rhs_label = relationship.definition["node_class"].__label__ rel_reference = f'{relationship.definition["node_class"]}_{part}' self._node_counters[rel_reference] 
+= 1 - rhs_name = ( - f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" - ) + if index + 1 == len(parts) and "alias" in relation: + # If an alias is defined, use it to store the last hop in the path + rhs_name = relation["alias"] + else: + rhs_name = ( + f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + ) rhs_ident = f"{rhs_name}:{rhs_label}" - self._additional_return(rhs_name) + if relation["include_in_return"]: + self._additional_return(rhs_name) if not stmt: lhs_label = source_class_iterator.__label__ lhs_name = lhs_label.lower() @@ -537,13 +543,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - else: + elif relation["include_in_return"]: self._additional_return(lhs_name) else: lhs_ident = stmt rel_ident = self.create_ident() - self._additional_return(rel_ident) + if relation["include_in_return"]: + self._additional_return(rel_ident) stmt = _rel_helper( lhs=lhs_ident, rhs=rhs_ident, @@ -683,7 +690,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) def build_query(self): - query = "" + query: str = "" if self._ast.lookup: query += self._ast.lookup @@ -710,12 +717,20 @@ def build_query(self): query += self._ast.with_clause query += " RETURN " + returned_items: list[str] = [] if self._ast.return_clause: - query += self._ast.return_clause + returned_items.append(self._ast.return_clause) if self._ast.additional_return: - if self._ast.return_clause: - query += ", " - query += ", ".join(self._ast.additional_return) + returned_items += self._ast.additional_return + if hasattr(self.node_set, "_extra_results"): + for varname, vardef in self.node_set._extra_results.items(): + if varname in returned_items: + # We're about to override an existing variable, delete it first to + # avoid duplicate error + returned_items.remove(varname) + 
returned_items.append(f"{str(vardef)} AS {varname}") + + query += ", ".join(returned_items) if self._ast.order_by: query += " ORDER BY " @@ -754,7 +769,6 @@ async def _count(self): async def _contains(self, node_element_id): # inject id = into ast if not self._ast.return_clause: - print(self._ast.additional_return) self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") @@ -881,6 +895,25 @@ class Optional: relation: str +@dataclass +class AggregatingFunction: + """Base aggregating function class.""" + + input_name: str + + +@dataclass +class Collect(AggregatingFunction): + """collect() function.""" + + distinct: bool = False + + def __str__(self): + if self.distinct: + return f"collect(DISTINCT {self.input_name})" + return f"collect({self.input_name})" + + class AsyncNodeSet(AsyncBaseSet): """ A class representing as set of nodes matching common query parameters @@ -908,6 +941,7 @@ def __init__(self, source): self.dont_match = {} self.relations_to_fetch: list = [] + self._extra_results: dict[str] = {} def __await__(self): return self.all().__await__() @@ -1062,18 +1096,56 @@ def order_by(self, *props): return self - def fetch_relations(self, *relation_names): + def _add_relations( + self, *relation_names, include_in_return=True, **aliased_relation_names + ): """Specify a set of relations to return.""" relations = [] - for relation_name in relation_names: - if isinstance(relation_name, Optional): - item = {"path": relation_name.relation, "optional": True} + + def register_relation_to_fetch(relation_def: Any, alias: str = None): + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} else: - item = {"path": relation_name} + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias relations.append(item) + + for relation_name in relation_names: + 
register_relation_to_fetch(relation_name) + for alias, relation_def in aliased_relation_names.items(): + register_relation_to_fetch(relation_def, alias) + self.relations_to_fetch = relations return self + def fetch_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse and return.""" + return self._add_relations(*relation_names, **aliased_relation_names) + + def traverse_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse only.""" + return self._add_relations( + *relation_names, include_in_return=False, **aliased_relation_names + ) + + def annotate(self, *vars, **aliased_vars): + """Annotate node set results with extra variables.""" + + def register_extra_var(vardef, varname: str = None): + if isinstance(vardef, AggregatingFunction): + self._extra_results[varname if varname else vardef.input_name] = vardef + else: + raise NotImplementedError + + for vardef in vars: + register_extra_var(vardef) + for varname, vardef in aliased_vars.items(): + register_extra_var(vardef, varname) + + return self + class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 928842c2..8f163c53 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -507,7 +507,7 @@ def build_traversal(self, traversal): return traversal_ident - def _additional_return(self, name): + def _additional_return(self, name: str): if name not in self._ast.additional_return and name != self._ast.return_clause: self._ast.additional_return.append(name) @@ -515,7 +515,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator 
= source_class - for index, part in enumerate(path.split("__")): + parts = path.split("__") + for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source if "node_class" not in relationship.definition: @@ -523,11 +524,16 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: rhs_label = relationship.definition["node_class"].__label__ rel_reference = f'{relationship.definition["node_class"]}_{part}' self._node_counters[rel_reference] += 1 - rhs_name = ( - f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" - ) + if index + 1 == len(parts) and "alias" in relation: + # If an alias is defined, use it to store the last hop in the path + rhs_name = relation["alias"] + else: + rhs_name = ( + f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + ) rhs_ident = f"{rhs_name}:{rhs_label}" - self._additional_return(rhs_name) + if relation["include_in_return"]: + self._additional_return(rhs_name) if not stmt: lhs_label = source_class_iterator.__label__ lhs_name = lhs_label.lower() @@ -537,13 +543,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - else: + elif relation["include_in_return"]: self._additional_return(lhs_name) else: lhs_ident = stmt rel_ident = self.create_ident() - self._additional_return(rel_ident) + if relation["include_in_return"]: + self._additional_return(rel_ident) stmt = _rel_helper( lhs=lhs_ident, rhs=rhs_ident, @@ -683,7 +690,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) def build_query(self): - query = "" + query: str = "" if self._ast.lookup: query += self._ast.lookup @@ -710,12 +717,20 @@ def build_query(self): query += self._ast.with_clause query += " RETURN " + returned_items: list[str] = [] if self._ast.return_clause: - query += self._ast.return_clause 
+ returned_items.append(self._ast.return_clause) if self._ast.additional_return: - if self._ast.return_clause: - query += ", " - query += ", ".join(self._ast.additional_return) + returned_items += self._ast.additional_return + if hasattr(self.node_set, "_extra_results"): + for varname, vardef in self.node_set._extra_results.items(): + if varname in returned_items: + # We're about to override an existing variable, delete it first to + # avoid duplicate error + returned_items.remove(varname) + returned_items.append(f"{str(vardef)} AS {varname}") + + query += ", ".join(returned_items) if self._ast.order_by: query += " ORDER BY " @@ -754,7 +769,6 @@ def _count(self): def _contains(self, node_element_id): # inject id = into ast if not self._ast.return_clause: - print(self._ast.additional_return) self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") @@ -877,6 +891,25 @@ class Optional: relation: str +@dataclass +class AggregatingFunction: + """Base aggregating function class.""" + + input_name: str + + +@dataclass +class Collect(AggregatingFunction): + """collect() function.""" + + distinct: bool = False + + def __str__(self): + if self.distinct: + return f"collect(DISTINCT {self.input_name})" + return f"collect({self.input_name})" + + class NodeSet(BaseSet): """ A class representing as set of nodes matching common query parameters @@ -904,6 +937,7 @@ def __init__(self, source): self.dont_match = {} self.relations_to_fetch: list = [] + self._extra_results: dict[str] = {} def __await__(self): return self.all().__await__() @@ -1058,18 +1092,56 @@ def order_by(self, *props): return self - def fetch_relations(self, *relation_names): + def _add_relations( + self, *relation_names, include_in_return=True, **aliased_relation_names + ): """Specify a set of relations to return.""" relations = [] - for relation_name in relation_names: - if isinstance(relation_name, Optional): - item = 
{"path": relation_name.relation, "optional": True} + + def register_relation_to_fetch(relation_def: Any, alias: str = None): + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} else: - item = {"path": relation_name} + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias relations.append(item) + + for relation_name in relation_names: + register_relation_to_fetch(relation_name) + for alias, relation_def in aliased_relation_names.items(): + register_relation_to_fetch(relation_def, alias) + self.relations_to_fetch = relations return self + def fetch_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse and return.""" + return self._add_relations(*relation_names, **aliased_relation_names) + + def traverse_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse only.""" + return self._add_relations( + *relation_names, include_in_return=False, **aliased_relation_names + ) + + def annotate(self, *vars, **aliased_vars): + """Annotate node set results with extra variables.""" + + def register_extra_var(vardef, varname: str = None): + if isinstance(vardef, AggregatingFunction): + self._extra_results[varname if varname else vardef.input_name] = vardef + else: + raise NotImplementedError + + for vardef in vars: + register_extra_var(vardef) + for varname, vardef in aliased_vars.items(): + register_extra_var(vardef, varname) + + return self + class Traversal(BaseSet): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index e3195448..27406f2d 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -15,12 +15,14 @@ Q, StringProperty, UniqueIdProperty, + adb, ) from neomodel._async_compat.util import AsyncUtil from neomodel.async_.match import ( AsyncNodeSet, AsyncQueryBuilder, AsyncTraversal, + Collect, Optional, ) from 
neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined @@ -595,6 +597,46 @@ async def test_fetch_relations(): ) +@mark_async_test +async def test_annotate_and_collect(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe 1002", price=99).save() + nescafe_gold = await Coffee(name="Nescafe 1003", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + await nescafe_gold.species.connect(arabica) + + result = ( + await Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species")) + .all() + ) + assert len(result) == 1 + assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + + result = ( + await Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + result = ( + await Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(all_species=Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 170a7363..f872dad4 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -15,10 +15,11 @@ StructuredNode, StructuredRel, UniqueIdProperty, + db, ) from neomodel._async_compat.util import Util from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined -from neomodel.sync_.match import NodeSet, 
Optional, QueryBuilder, Traversal +from neomodel.sync_.match import Collect, NodeSet, Optional, QueryBuilder, Traversal class SupplierRel(StructuredRel): @@ -584,6 +585,46 @@ def test_fetch_relations(): ) +@mark_sync_test +def test_annotate_and_collect(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe 1002", price=99).save() + nescafe_gold = Coffee(name="Nescafe 1003", price=11).save() + + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + nescafe_gold.species.connect(arabica) + + result = ( + Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species")) + .all() + ) + assert len(result) == 1 + assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + + result = ( + Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + result = ( + Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(all_species=Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From 95566a1b53ba701c3038ec404c8864b43a6ffc57 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 20 Sep 2024 17:16:12 +0200 Subject: [PATCH 02/42] Prepare rc branch --- doc/source/configuration.rst | 2 +- neomodel/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index e8c0a38b..511318c4 100644 --- a/doc/source/configuration.rst +++ 
b/doc/source/configuration.rst @@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default config.RESOLVER = None # default config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = neomodel/v5.3.2 # default + config.USER_AGENT = neomodel/v5.4.0 # default Setting the database name, if different from the default one:: diff --git a/neomodel/_version.py b/neomodel/_version.py index 07f0e9e2..fc30498f 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.3.2" +__version__ = "5.4.0" From 8c7d4a9bbdc1eb64a856d8313b33d93da6d92fb6 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 16:19:33 +0200 Subject: [PATCH 03/42] Added method to resolve a subgraph from a fetch_relations() call. --- docker-scripts/docker-neo4j.sh | 3 +- neomodel/async_/match.py | 87 +++++++++++++++++++++++++++++++-- neomodel/sync_/match.py | 89 ++++++++++++++++++++++++++++++++-- test/async_/test_match_api.py | 34 +++++++++++++ test/sync_/test_match_api.py | 34 +++++++++++++ 5 files changed, 240 insertions(+), 7 deletions(-) diff --git a/docker-scripts/docker-neo4j.sh b/docker-scripts/docker-neo4j.sh index 99aabfff..6b146c95 100644 --- a/docker-scripts/docker-neo4j.sh +++ b/docker-scripts/docker-neo4j.sh @@ -5,4 +5,5 @@ docker run \ --env NEO4J_AUTH=neo4j/foobarbaz \ --env NEO4J_ACCEPT_LICENSE_AGREEMENT=yes \ --env NEO4JLABS_PLUGINS='["apoc"]' \ - neo4j:$1 \ No newline at end of file + --rm \ + neo4j:$1 diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index c203dcc2..2796f504 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -5,6 +5,7 @@ from typing import Any, Optional from neomodel.async_.core import AsyncStructuredNode, adb +from neomodel.async_.relationship import AsyncStructuredRel from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from 
neomodel.properties import AliasProperty, ArrayProperty @@ -414,16 +415,18 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.subgraph: dict = {} class AsyncQueryBuilder: - def __init__(self, node_set): + def __init__(self, node_set, with_subgraph: bool = False): self.node_set = node_set self._ast = QueryAST() self._query_params = {} self._place_holder_registry = {} self._ident_count = 0 self._node_counters = defaultdict(int) + self._with_subgraph: bool = with_subgraph async def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -516,6 +519,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: stmt: str = "" source_class_iterator = source_class parts = path.split("__") + if self._with_subgraph: + subgraph = self._ast.subgraph for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source @@ -549,6 +554,13 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: lhs_ident = stmt rel_ident = self.create_ident() + if self._with_subgraph and part not in self._ast.subgraph: + subgraph[part] = { + "target": relationship.definition["node_class"], + "children": {}, + "variable_name": rhs_name, + "rel_variable_name": rel_ident, + } if relation["include_in_return"]: self._additional_return(rel_ident) stmt = _rel_helper( @@ -559,6 +571,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] + if self._with_subgraph: + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -778,7 +792,7 @@ async def _contains(self, node_element_id): self._query_params[place_holder] = node_element_id return await self._count() >= 1 - async def _execute(self, lazy=False): + async def _execute(self, lazy: bool = 
False, dict_output: bool = False): if lazy: # inject id() into return or return_set if self._ast.return_clause: @@ -791,9 +805,13 @@ async def _execute(self, lazy=False): for item in self._ast.additional_return ] query = self.build_query() - results, _ = await adb.cypher_query( + results, prop_names = await adb.cypher_query( query, self._query_params, resolve_objects=True ) + if dict_output: + for item in results: + yield dict(zip(prop_names, item)) + return # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first @@ -1146,6 +1164,69 @@ def register_extra_var(vardef, varname: str = None): return self + def _to_subgraph(self, root_node, other_nodes, subgraph): + """Recursive method to build root_node's relation graph from subgraph.""" + root_node._relations = {} + for name, relation_def in subgraph.items(): + for var_name, node in other_nodes.items(): + if ( + var_name + not in [ + relation_def["variable_name"], + relation_def["rel_variable_name"], + ] + or node is None + ): + continue + if isinstance(node, list): + if len(node) > 0 and isinstance(node[0], AsyncStructuredRel): + name += "_relationship" + root_node._relations[name] = [] + for item in node: + root_node._relations[name].append( + self._to_subgraph( + item, other_nodes, relation_def["children"] + ) + ) + else: + if isinstance(node, AsyncStructuredRel): + name += "_relationship" + root_node._relations[name] = self._to_subgraph( + node, other_nodes, relation_def["children"] + ) + + return root_node + + async def resolve_subgraph(self) -> list: + """ + Convert every result contained in this node set to a subgraph. + + By default, we receive results from neomodel as a list of + nodes without the hierarchy. 
This method tries to rebuild this + hierarchy without overriding anything in the node, that's why + we use a dedicated property to store node's relations. + + """ + results: list = [] + qbuilder = self.query_cls(self, with_subgraph=True) + await qbuilder.build_ast() + all_nodes = qbuilder._execute(dict_output=True) + other_nodes = {} + root_node = None + async for row in all_nodes: + for name, node in row.items(): + if node.__class__ is self.source and "_" not in name: + root_node = node + else: + if isinstance(node, list) and isinstance(node[0], list): + other_nodes[name] = node[0] + else: + other_nodes[name] = node + results.append( + self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) + ) + return results + class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 8f163c53..8221c979 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -8,6 +8,7 @@ from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty, ArrayProperty from neomodel.sync_.core import StructuredNode, db +from neomodel.sync_.relationship import StructuredRel from neomodel.util import INCOMING, OUTGOING @@ -414,16 +415,18 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.subgraph: dict = {} class QueryBuilder: - def __init__(self, node_set): + def __init__(self, node_set, with_subgraph: bool = False): self.node_set = node_set self._ast = QueryAST() self._query_params = {} self._place_holder_registry = {} self._ident_count = 0 self._node_counters = defaultdict(int) + self._with_subgraph: bool = with_subgraph def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -516,6 +519,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: stmt: str = "" source_class_iterator = source_class parts = path.split("__") + if self._with_subgraph: + subgraph = self._ast.subgraph 
for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source @@ -549,6 +554,13 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: lhs_ident = stmt rel_ident = self.create_ident() + if self._with_subgraph and part not in self._ast.subgraph: + subgraph[part] = { + "target": relationship.definition["node_class"], + "children": {}, + "variable_name": rhs_name, + "rel_variable_name": rel_ident, + } if relation["include_in_return"]: self._additional_return(rel_ident) stmt = _rel_helper( @@ -559,6 +571,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] + if self._with_subgraph: + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -776,7 +790,7 @@ def _contains(self, node_element_id): self._query_params[place_holder] = node_element_id return self._count() >= 1 - def _execute(self, lazy=False): + def _execute(self, lazy: bool = False, dict_output: bool = False): if lazy: # inject id() into return or return_set if self._ast.return_clause: @@ -789,7 +803,13 @@ def _execute(self, lazy=False): for item in self._ast.additional_return ] query = self.build_query() - results, _ = db.cypher_query(query, self._query_params, resolve_objects=True) + results, prop_names = db.cypher_query( + query, self._query_params, resolve_objects=True + ) + if dict_output: + for item in results: + yield dict(zip(prop_names, item)) + return # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. 
# It seems that certain calls are only supposed to be focusing to the first @@ -1142,6 +1162,69 @@ def register_extra_var(vardef, varname: str = None): return self + def _to_subgraph(self, root_node, other_nodes, subgraph): + """Recursive method to build root_node's relation graph from subgraph.""" + root_node._relations = {} + for name, relation_def in subgraph.items(): + for var_name, node in other_nodes.items(): + if ( + var_name + not in [ + relation_def["variable_name"], + relation_def["rel_variable_name"], + ] + or node is None + ): + continue + if isinstance(node, list): + if len(node) > 0 and isinstance(node[0], StructuredRel): + name += "_relationship" + root_node._relations[name] = [] + for item in node: + root_node._relations[name].append( + self._to_subgraph( + item, other_nodes, relation_def["children"] + ) + ) + else: + if isinstance(node, StructuredRel): + name += "_relationship" + root_node._relations[name] = self._to_subgraph( + node, other_nodes, relation_def["children"] + ) + + return root_node + + def resolve_subgraph(self) -> list: + """ + Convert every result contained in this node set to a subgraph. + + By default, we receive results from neomodel as a list of + nodes without the hierarchy. This method tries to rebuild this + hierarchy without overriding anything in the node, that's why + we use a dedicated property to store node's relations. 
+ + """ + results: list = [] + qbuilder = self.query_cls(self, with_subgraph=True) + qbuilder.build_ast() + all_nodes = qbuilder._execute(dict_output=True) + other_nodes = {} + root_node = None + for row in all_nodes: + for name, node in row.items(): + if node.__class__ is self.source and "_" not in name: + root_node = node + else: + if isinstance(node, list) and isinstance(node[0], list): + other_nodes[name] = node[0] + else: + other_nodes[name] = node + results.append( + self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) + ) + return results + class Traversal(BaseSet): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 2175d604..ff8261ef 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -638,6 +638,40 @@ async def test_annotate_and_collect(): assert len(result[0][1][0]) == 2 # 2 species must be there +@mark_async_test +async def test_resolve_subgraph(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + + result = await Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() + assert len(result) == 2 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert robusta == coffees._relations["species"] + + assert hasattr(result[1], "_relations") + assert "coffees" in result[1]._relations + coffees = 
result[1]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert arabica == coffees._relations["species"] + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index d22e4100..5aecf3cf 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -626,6 +626,40 @@ def test_annotate_and_collect(): assert len(result[0][1][0]) == 2 # 2 species must be there +@mark_sync_test +def test_resolve_subgraph(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() + + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + + result = Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() + assert len(result) == 2 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert robusta == coffees._relations["species"] + + assert hasattr(result[1], "_relations") + assert "coffees" in result[1]._relations + coffees = result[1]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert arabica == coffees._relations["species"] + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From fefabce3fbf795d942e7bbf08c08ac493c8d9184 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 17:26:47 +0200 Subject: [PATCH 04/42] Added 
constraints to limit use cases of resolve_subgraph() --- neomodel/async_/match.py | 61 +++++++++++++++++++++-------------- neomodel/sync_/match.py | 61 +++++++++++++++++++++-------------- test/async_/test_match_api.py | 19 +++++++++++ test/sync_/test_match_api.py | 19 +++++++++++ 4 files changed, 110 insertions(+), 50 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 2796f504..03033415 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1114,39 +1114,42 @@ def order_by(self, *props): return self - def _add_relations( - self, *relation_names, include_in_return=True, **aliased_relation_names + def _register_relation_to_fetch( + self, relation_def: Any, alias: str = None, include_in_return: bool = True ): - """Specify a set of relations to return.""" - relations = [] - - def register_relation_to_fetch(relation_def: Any, alias: str = None): - if isinstance(relation_def, Optional): - item = {"path": relation_def.relation, "optional": True} - else: - item = {"path": relation_def} - item["include_in_return"] = include_in_return - if alias: - item["alias"] = alias - relations.append(item) + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} + else: + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias + return item + def fetch_relations(self, *relation_names): + """Specify a set of relations to traverse and return.""" + relations = [] for relation_name in relation_names: - register_relation_to_fetch(relation_name) - for alias, relation_def in aliased_relation_names.items(): - register_relation_to_fetch(relation_def, alias) - + relations.append(self._register_relation_to_fetch(relation_name)) self.relations_to_fetch = relations return self - def fetch_relations(self, *relation_names, **aliased_relation_names): - """Specify a set of relations to traverse and return.""" - return self._add_relations(*relation_names, 
**aliased_relation_names) - def traverse_relations(self, *relation_names, **aliased_relation_names): """Specify a set of relations to traverse only.""" - return self._add_relations( - *relation_names, include_in_return=False, **aliased_relation_names - ) + relations = [] + for relation_name in relation_names: + relations.append( + self._register_relation_to_fetch(relation_name, include_in_return=False) + ) + for alias, relation_def in aliased_relation_names.items(): + relations.append( + self._register_relation_to_fetch( + relation_def, alias, include_in_return=False + ) + ) + + self.relations_to_fetch = relations + return self def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" @@ -1207,6 +1210,14 @@ async def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ + if not self.relations_to_fetch: + raise RuntimeError( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ) + if not self.relations_to_fetch[0]["include_in_return"]: + raise NotImplementedError( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
+ ) results: list = [] qbuilder = self.query_cls(self, with_subgraph=True) await qbuilder.build_ast() diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 8221c979..0f8b044e 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1112,39 +1112,42 @@ def order_by(self, *props): return self - def _add_relations( - self, *relation_names, include_in_return=True, **aliased_relation_names + def _register_relation_to_fetch( + self, relation_def: Any, alias: str = None, include_in_return: bool = True ): - """Specify a set of relations to return.""" - relations = [] - - def register_relation_to_fetch(relation_def: Any, alias: str = None): - if isinstance(relation_def, Optional): - item = {"path": relation_def.relation, "optional": True} - else: - item = {"path": relation_def} - item["include_in_return"] = include_in_return - if alias: - item["alias"] = alias - relations.append(item) + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} + else: + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias + return item + def fetch_relations(self, *relation_names): + """Specify a set of relations to traverse and return.""" + relations = [] for relation_name in relation_names: - register_relation_to_fetch(relation_name) - for alias, relation_def in aliased_relation_names.items(): - register_relation_to_fetch(relation_def, alias) - + relations.append(self._register_relation_to_fetch(relation_name)) self.relations_to_fetch = relations return self - def fetch_relations(self, *relation_names, **aliased_relation_names): - """Specify a set of relations to traverse and return.""" - return self._add_relations(*relation_names, **aliased_relation_names) - def traverse_relations(self, *relation_names, **aliased_relation_names): """Specify a set of relations to traverse only.""" - return self._add_relations( - *relation_names, include_in_return=False, 
**aliased_relation_names - ) + relations = [] + for relation_name in relation_names: + relations.append( + self._register_relation_to_fetch(relation_name, include_in_return=False) + ) + for alias, relation_def in aliased_relation_names.items(): + relations.append( + self._register_relation_to_fetch( + relation_def, alias, include_in_return=False + ) + ) + + self.relations_to_fetch = relations + return self def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" @@ -1205,6 +1208,14 @@ def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ + if not self.relations_to_fetch: + raise RuntimeError( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ) + if not self.relations_to_fetch[0]["include_in_return"]: + raise NotImplementedError( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." + ) results: list = [] qbuilder = self.query_cls(self, with_subgraph=True) qbuilder.build_ast() diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index ff8261ef..216fbe7e 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from test._async_compat import mark_async_test @@ -654,6 +655,24 @@ async def test_resolve_subgraph(): await nescafe.species.connect(arabica) await nescafe_gold.species.connect(robusta) + with raises( + RuntimeError, + match=re.escape( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ), + ): + result = await Supplier.nodes.resolve_subgraph() + + with raises( + NotImplementedError, + match=re.escape( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
+ ), + ): + result = await Supplier.nodes.traverse_relations( + "coffees__species" + ).resolve_subgraph() + result = await Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() assert len(result) == 2 diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 5aecf3cf..d015776a 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from test._async_compat import mark_sync_test @@ -642,6 +643,24 @@ def test_resolve_subgraph(): nescafe.species.connect(arabica) nescafe_gold.species.connect(robusta) + with raises( + RuntimeError, + match=re.escape( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ), + ): + result = Supplier.nodes.resolve_subgraph() + + with raises( + NotImplementedError, + match=re.escape( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." + ), + ): + result = Supplier.nodes.traverse_relations( + "coffees__species" + ).resolve_subgraph() + result = Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() assert len(result) == 2 From 065f1bd230ff91e0408a03b222b3f86ab6f9102b Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 17:32:49 +0200 Subject: [PATCH 05/42] Cannot ensure output order so make sure tests are running. 
--- test/async_/test_match_api.py | 2 -- test/sync_/test_match_api.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 216fbe7e..1c61589d 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -681,14 +681,12 @@ async def test_resolve_subgraph(): coffees = result[0]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert robusta == coffees._relations["species"] assert hasattr(result[1], "_relations") assert "coffees" in result[1]._relations coffees = result[1]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert arabica == coffees._relations["species"] @mark_async_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index d015776a..9f24afbf 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -669,14 +669,12 @@ def test_resolve_subgraph(): coffees = result[0]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert robusta == coffees._relations["species"] assert hasattr(result[1], "_relations") assert "coffees" in result[1]._relations coffees = result[1]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert arabica == coffees._relations["species"] @mark_sync_test From 40a60fbe4f325fe3f8cb6f963897fba51eb8f821 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 17:36:30 +0200 Subject: [PATCH 06/42] Added test to cover use case with Optional match. 
--- test/async_/test_match_api.py | 27 +++++++++++++++++++++++++++ test/sync_/test_match_api.py | 27 +++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 1c61589d..b3546ef9 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -689,6 +689,33 @@ async def test_resolve_subgraph(): assert "species" in coffees._relations +@mark_async_test +async def test_resolve_subgraph_optional(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + + result = await Supplier.nodes.fetch_relations( + Optional("coffees__species") + ).resolve_subgraph() + assert len(result) == 1 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert coffees._relations["species"] == arabica + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 9f24afbf..ab843639 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -677,6 +677,33 @@ def test_resolve_subgraph(): assert "species" in coffees._relations +@mark_sync_test +def test_resolve_subgraph_optional(): + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() + + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + + result = Supplier.nodes.fetch_relations( + Optional("coffees__species") + ).resolve_subgraph() + assert len(result) == 1 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert coffees._relations["species"] == arabica + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From 55e04bcaa57114434ba354bb54dea0deceac3a02 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 13:06:58 +0200 Subject: [PATCH 07/42] Added possibility to specify subqueries. 
--- neomodel/async_/match.py | 48 ++++++++++++++++++++++++++++++----- neomodel/sync_/match.py | 46 ++++++++++++++++++++++++++++----- test/async_/test_match_api.py | 36 ++++++++++++++++++++++++++ test/sync_/test_match_api.py | 36 ++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 14 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 03033415..f3390340 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -419,7 +419,9 @@ def __init__( class AsyncQueryBuilder: - def __init__(self, node_set, with_subgraph: bool = False): + def __init__( + self, node_set, with_subgraph: bool = False, subquery_context: bool = False + ): self.node_set = node_set self._ast = QueryAST() self._query_params = {} @@ -427,6 +429,7 @@ def __init__(self, node_set, with_subgraph: bool = False): self._ident_count = 0 self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph + self._subquery_context: bool = subquery_context async def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -442,7 +445,7 @@ async def build_ast(self): return self - async def build_source(self, source): + async def build_source(self, source) -> str: if isinstance(source, AsyncTraversal): return await self.build_traversal(source) if isinstance(source, AsyncNodeSet): @@ -548,6 +551,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name + if self._subquery_context: + # Don't include label in identifier if we are in a subquery + lhs_ident = lhs_name elif relation["include_in_return"]: self._additional_return(lhs_name) else: @@ -594,7 +600,7 @@ async def build_node(self, node): self._ast.result_class = node.__class__ return ident - def build_label(self, ident, cls): + def build_label(self, ident, cls) -> str: """ match nodes by a label """ @@ -703,7 +709,7 @@ def build_where_stmt(self, ident, filters, 
q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) - def build_query(self): + def build_query(self) -> str: query: str = "" if self._ast.lookup: @@ -730,9 +736,15 @@ def build_query(self): query += " WITH " query += self._ast.with_clause - query += " RETURN " returned_items: list[str] = [] - if self._ast.return_clause: + if hasattr(self.node_set, "_subqueries"): + for subquery, return_set in self.node_set._subqueries: + outer_primary_var: str = self._ast.return_clause + query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " + returned_items += return_set + + query += " RETURN " + if self._ast.return_clause and not self._subquery_context: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -960,6 +972,7 @@ def __init__(self, source): self.relations_to_fetch: list = [] self._extra_results: dict[str] = {} + self._subqueries: list[tuple(str, list[str])] = [] def __await__(self): return self.all().__await__() @@ -1238,6 +1251,27 @@ async def resolve_subgraph(self) -> list: ) return results + async def subquery( + self, nodeset: "AsyncNodeSet", return_set: list[str] + ) -> "AsyncNodeSet": + """Add a subquery to this node set. + + A subquery is a regular cypher query but executed within the context of a CALL + statement. Such query will generally fetch additional variables which must be + declared inside return_set variable in order to be included in the final RETURN + statement. 
+ """ + qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast() + for var in return_set: + if ( + var != qbuilder._ast.return_clause + and var not in qbuilder._ast.additional_return + and var not in nodeset._extra_results + ): + raise RuntimeError(f"Variable '{var}' is not returned by subquery.") + self._subqueries.append((qbuilder.build_query(), return_set)) + return self + class AsyncTraversal(AsyncBaseSet): """ @@ -1251,7 +1285,7 @@ class AsyncTraversal(AsyncBaseSet): :type name: :class:`str` :param definition: A relationship definition that most certainly deserves a documentation here. - :type defintion: :class:`dict` + :type definition: :class:`dict` """ def __await__(self): diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 0f8b044e..4ea10560 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -419,7 +419,9 @@ def __init__( class QueryBuilder: - def __init__(self, node_set, with_subgraph: bool = False): + def __init__( + self, node_set, with_subgraph: bool = False, subquery_context: bool = False + ): self.node_set = node_set self._ast = QueryAST() self._query_params = {} @@ -427,6 +429,7 @@ def __init__(self, node_set, with_subgraph: bool = False): self._ident_count = 0 self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph + self._subquery_context: bool = subquery_context def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -442,7 +445,7 @@ def build_ast(self): return self - def build_source(self, source): + def build_source(self, source) -> str: if isinstance(source, Traversal): return self.build_traversal(source) if isinstance(source, NodeSet): @@ -548,6 +551,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name + if self._subquery_context: + # Don't include label in identifier if we are in a subquery + lhs_ident = lhs_name elif 
relation["include_in_return"]: self._additional_return(lhs_name) else: @@ -594,7 +600,7 @@ def build_node(self, node): self._ast.result_class = node.__class__ return ident - def build_label(self, ident, cls): + def build_label(self, ident, cls) -> str: """ match nodes by a label """ @@ -703,7 +709,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) - def build_query(self): + def build_query(self) -> str: query: str = "" if self._ast.lookup: @@ -730,9 +736,15 @@ def build_query(self): query += " WITH " query += self._ast.with_clause - query += " RETURN " returned_items: list[str] = [] - if self._ast.return_clause: + if hasattr(self.node_set, "_subqueries"): + for subquery, return_set in self.node_set._subqueries: + outer_primary_var: str = self._ast.return_clause + query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " + returned_items += return_set + + query += " RETURN " + if self._ast.return_clause and not self._subquery_context: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -958,6 +970,7 @@ def __init__(self, source): self.relations_to_fetch: list = [] self._extra_results: dict[str] = {} + self._subqueries: list[tuple(str, list[str])] = [] def __await__(self): return self.all().__await__() @@ -1236,6 +1249,25 @@ def resolve_subgraph(self) -> list: ) return results + def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet": + """Add a subquery to this node set. + + A subquery is a regular cypher query but executed within the context of a CALL + statement. Such query will generally fetch additional variables which must be + declared inside return_set variable in order to be included in the final RETURN + statement. 
+ """ + qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast() + for var in return_set: + if ( + var != qbuilder._ast.return_clause + and var not in qbuilder._ast.additional_return + and var not in nodeset._extra_results + ): + raise RuntimeError(f"Variable '{var}' is not returned by subquery.") + self._subqueries.append((qbuilder.build_query(), return_set)) + return self + class Traversal(BaseSet): """ @@ -1249,7 +1281,7 @@ class Traversal(BaseSet): :type name: :class:`str` :param definition: A relationship definition that most certainly deserves a documentation here. - :type defintion: :class:`dict` + :type definition: :class:`dict` """ def __await__(self): diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index b3546ef9..c4ff24d9 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -716,6 +716,42 @@ async def test_resolve_subgraph_optional(): assert coffees._relations["species"] == arabica +@mark_async_test +async def test_subquery(): + # Clean DB before we start anything... 
+ await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1) + await nescafe.suppliers.connect(supplier2) + await nescafe.species.connect(arabica) + + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + ) + result = await result.all() + assert len(result) == 1 + assert len(result[0][0][0]) == 2 + + with raises( + RuntimeError, + match=re.escape("Variable 'unknown' is not returned by subquery."), + ): + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["unknown"], + ) + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index ab843639..09e19bc1 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -704,6 +704,42 @@ def test_resolve_subgraph_optional(): assert coffees._relations["species"] == arabica +@mark_sync_test +def test_subquery(): + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1) + nescafe.suppliers.connect(supplier2) + nescafe.species.connect(arabica) + + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + ) + result = result.all() + assert len(result) == 1 + assert len(result[0][0][0]) == 2 + + with raises( + RuntimeError, + match=re.escape("Variable 'unknown' is not returned by subquery."), + ): + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["unknown"], + ) + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From 56f0ff0361e4d4e4f9d38d85f3be43e9942937dd Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 13:13:26 +0200 Subject: [PATCH 08/42] Compat with python <= 3.8 --- neomodel/async_/match.py | 4 ++-- neomodel/sync_/match.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index f3390340..abe796e0 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, List, Optional from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import AsyncStructuredRel @@ -1252,7 +1252,7 @@ async def resolve_subgraph(self) -> list: return results async def subquery( - self, nodeset: "AsyncNodeSet", return_set: list[str] + self, nodeset: "AsyncNodeSet", return_set: List[str] ) -> "AsyncNodeSet": """Add a subquery to this node set. 
diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 4ea10560..4ba70e2f 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, List, Optional from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -1249,7 +1249,7 @@ def resolve_subgraph(self) -> list: ) return results - def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet": + def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": """Add a subquery to this node set. A subquery is a regular cypher query but executed within the context of a CALL From c73f5233a1b4185a038051761d7a566ab2720e98 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 14:13:19 +0200 Subject: [PATCH 09/42] Added support for last() scalar function --- neomodel/async_/match.py | 21 ++++++++++++++++++--- neomodel/sync_/match.py | 21 ++++++++++++++++++--- test/async_/test_match_api.py | 6 ++++-- test/sync_/test_match_api.py | 14 +++++++++++--- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index abe796e0..7048c86b 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import AsyncStructuredRel @@ -944,6 +944,21 @@ def __str__(self): return f"collect({self.input_name})" +@dataclass +class ScalarFunction: + """Base scalar function class.""" + + input_name: Union[str, AggregatingFunction] + + +@dataclass +class Last(ScalarFunction): + """last() function.""" + + def __str__(self) -> str: + return 
f"last({str(self.input_name)})" + + class AsyncNodeSet(AsyncBaseSet): """ A class representing as set of nodes matching common query parameters @@ -1167,8 +1182,8 @@ def traverse_relations(self, *relation_names, **aliased_relation_names): def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" - def register_extra_var(vardef, varname: str = None): - if isinstance(vardef, AggregatingFunction): + def register_extra_var(vardef, varname: Union[str, None] = None): + if isinstance(vardef, (AggregatingFunction, ScalarFunction)): self._extra_results[varname if varname else vardef.input_name] = vardef else: raise NotImplementedError diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 4ba70e2f..0bede175 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -942,6 +942,21 @@ def __str__(self): return f"collect({self.input_name})" +@dataclass +class ScalarFunction: + """Base scalar function class.""" + + input_name: Union[str, AggregatingFunction] + + +@dataclass +class Last(ScalarFunction): + """last() function.""" + + def __str__(self) -> str: + return f"last({str(self.input_name)})" + + class NodeSet(BaseSet): """ A class representing as set of nodes matching common query parameters @@ -1165,8 +1180,8 @@ def traverse_relations(self, *relation_names, **aliased_relation_names): def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" - def register_extra_var(vardef, varname: str = None): - if isinstance(vardef, AggregatingFunction): + def register_extra_var(vardef, varname: Union[str, None] = None): + if isinstance(vardef, (AggregatingFunction, ScalarFunction)): self._extra_results[varname if 
varname else vardef.input_name] = vardef else: raise NotImplementedError diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index c4ff24d9..043aef23 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -24,6 +24,7 @@ AsyncQueryBuilder, AsyncTraversal, Collect, + Last, Optional, ) from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined @@ -732,13 +733,14 @@ async def test_subquery(): result = await Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( - supps=Collect("suppliers") + supps=Last(Collect("suppliers")) ), ["supps"], ) result = await result.all() assert len(result) == 1 - assert len(result[0][0][0]) == 2 + assert len(result[0]) == 2 + assert result[0][0] == supplier1 with raises( RuntimeError, diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 09e19bc1..2ca283e9 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -20,7 +20,14 @@ ) from neomodel._async_compat.util import Util from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined -from neomodel.sync_.match import Collect, NodeSet, Optional, QueryBuilder, Traversal +from neomodel.sync_.match import ( + Collect, + Last, + NodeSet, + Optional, + QueryBuilder, + Traversal, +) class SupplierRel(StructuredRel): @@ -720,13 +727,14 @@ def test_subquery(): result = Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( - supps=Collect("suppliers") + supps=Last(Collect("suppliers")) ), ["supps"], ) result = result.all() assert len(result) == 1 - assert len(result[0][0][0]) == 2 + assert len(result[0]) == 2 + assert result[0][0] == supplier1 with raises( RuntimeError, From f0fad2c554c0626d8ebb40062e9ba325975b3e4f Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 15:03:16 +0200 Subject: [PATCH 10/42] Specify order to make test work --- neomodel/async_/match.py | 2 +- 
neomodel/sync_/match.py | 2 +- test/async_/test_match_api.py | 2 +- test/sync_/test_match_api.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 7048c86b..304a976f 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -739,7 +739,7 @@ def build_query(self) -> str: returned_items: list[str] = [] if hasattr(self.node_set, "_subqueries"): for subquery, return_set in self.node_set._subqueries: - outer_primary_var: str = self._ast.return_clause + outer_primary_var = self._ast.return_clause query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " returned_items += return_set diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 0bede175..1c4645c4 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -739,7 +739,7 @@ def build_query(self) -> str: returned_items: list[str] = [] if hasattr(self.node_set, "_subqueries"): for subquery, return_set in self.node_set._subqueries: - outer_primary_var: str = self._ast.return_clause + outer_primary_var = self._ast.return_clause query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " returned_items += return_set diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 043aef23..1a1fd56c 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -737,7 +737,7 @@ async def test_subquery(): ), ["supps"], ) - result = await result.all() + result = await result.order_by("name").all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier1 diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 2ca283e9..1abea05d 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -731,7 +731,7 @@ def test_subquery(): ), ["supps"], ) - result = result.all() + result = result.order_by("name").all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier1 From 
12515e442fb4444feb6ef02775697d09e9aeffe3 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 15:12:07 +0200 Subject: [PATCH 11/42] Fixed unit test --- test/async_/test_match_api.py | 3 +-- test/sync_/test_match_api.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 1a1fd56c..52e79406 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -737,10 +737,9 @@ async def test_subquery(): ), ["supps"], ) - result = await result.order_by("name").all() + result = await result.all() assert len(result) == 1 assert len(result[0]) == 2 - assert result[0][0] == supplier1 with raises( RuntimeError, diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 1abea05d..a6d23523 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -731,10 +731,9 @@ def test_subquery(): ), ["supps"], ) - result = result.order_by("name").all() + result = result.all() assert len(result) == 1 assert len(result[0]) == 2 - assert result[0][0] == supplier1 with raises( RuntimeError, From 8d980d1d13dc72af106f6f99f57cb2a17dd2ebf5 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Thu, 3 Oct 2024 17:05:52 +0200 Subject: [PATCH 12/42] Added simple way to inject WITH statements --- neomodel/async_/match.py | 133 ++++++++++++++++++++++++---------- neomodel/sync_/match.py | 133 ++++++++++++++++++++++++---------- test/async_/test_match_api.py | 43 ++++++++++- test/sync_/test_match_api.py | 43 ++++++++++- 4 files changed, 272 insertions(+), 80 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 304a976f..01613e5a 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,9 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Optional, Union +from typing import Any, List +from typing import Optional as TOptional +from typing 
import Tuple, Union from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import AsyncStructuredRel @@ -375,34 +377,34 @@ def process_has_args(cls, kwargs): class QueryAST: - match: Optional[list] - optional_match: Optional[list] - where: Optional[list] - with_clause: Optional[str] - return_clause: Optional[str] - order_by: Optional[str] - skip: Optional[int] - limit: Optional[int] - result_class: Optional[type] - lookup: Optional[str] - additional_return: Optional[list] - is_count: Optional[bool] + match: TOptional[list] + optional_match: TOptional[list] + where: TOptional[list] + with_clause: TOptional[str] + return_clause: TOptional[str] + order_by: TOptional[str] + skip: TOptional[int] + limit: TOptional[int] + result_class: TOptional[type] + lookup: TOptional[str] + additional_return: TOptional[list] + is_count: TOptional[bool] def __init__( self, - match: Optional[list] = None, - optional_match: Optional[list] = None, - where: Optional[list] = None, - with_clause: Optional[str] = None, - return_clause: Optional[str] = None, - order_by: Optional[str] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - result_class: Optional[type] = None, - lookup: Optional[str] = None, - additional_return: Optional[list] = None, - is_count: Optional[bool] = False, - ): + match: TOptional[list] = None, + optional_match: TOptional[list] = None, + where: TOptional[list] = None, + with_clause: TOptional[str] = None, + return_clause: TOptional[str] = None, + order_by: TOptional[str] = None, + skip: TOptional[int] = None, + limit: TOptional[int] = None, + result_class: TOptional[type] = None, + lookup: TOptional[str] = None, + additional_return: TOptional[list] = None, + is_count: TOptional[bool] = False, + ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] self.where = where if where else [] @@ -421,7 +423,7 @@ def __init__( class AsyncQueryBuilder: def __init__( 
self, node_set, with_subgraph: bool = False, subquery_context: bool = False - ): + ) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params = {} @@ -430,8 +432,9 @@ def __init__( self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context + self._relation_identifiers: dict[str, str] = {} - async def build_ast(self): + async def build_ast(self) -> "AsyncQueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): for relation in self.node_set.relations_to_fetch: self.build_traversal_from_path(relation, self.node_set.source) @@ -474,9 +477,12 @@ async def build_source(self, source) -> str: return await self.build_node(source) raise ValueError("Unknown source type " + repr(source)) - def create_ident(self): + def create_ident(self, relation_name: TOptional[str] = None) -> str: self._ident_count += 1 - return "r" + str(self._ident_count) + result = f"r{self._ident_count}" + if relation_name: + self._relation_identifiers[relation_name] = result + return result def build_order_by(self, ident, source): if "?" 
in source.order_by_elements: @@ -524,8 +530,12 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: parts = path.split("__") if self._with_subgraph: subgraph = self._ast.subgraph + rel_iterator: str = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) + if rel_iterator: + rel_iterator += "__" + rel_iterator += part # build source if "node_class" not in relationship.definition: relationship.lookup_node_class() @@ -559,7 +569,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt - rel_ident = self.create_ident() + rel_ident = self.create_ident(rel_iterator) if self._with_subgraph and part not in self._ast.subgraph: subgraph[part] = { "target": relationship.definition["node_class"], @@ -736,6 +746,32 @@ def build_query(self) -> str: query += " WITH " query += self._ast.with_clause + if hasattr(self.node_set, "_intermediate_transforms"): + for transform in self.node_set._intermediate_transforms: + query += " WITH " + injected_vars: list = [] + for name, source in transform["vars"].items(): + if type(source) is str: + injected_vars.append(f"{source} AS {name}") + elif isinstance(source, RelationNameResolver): + internal_name = self._relation_identifiers.get(source.relation) + if not internal_name: + raise ValueError( + f"Unable to resolve variable name for relation {source.relation}." 
+ ) + injected_vars.append(f"{internal_name} AS {name}") + query += ",".join(injected_vars) + if not transform["ordering"]: + continue + query += " ORDER BY " + ordering: list = [] + for item in transform["ordering"]: + if item.startswith("-"): + ordering.append(f"{item[1:]} DESC") + else: + ordering.append(item) + query += ",".join(ordering) + returned_items: list[str] = [] if hasattr(self.node_set, "_subqueries"): for subquery, return_set in self.node_set._subqueries: @@ -959,12 +995,25 @@ def __str__(self) -> str: return f"last({str(self.input_name)})" +@dataclass +class RelationNameResolver: + """Helper to refer to a relation variable name. + + Since variable names are generated automatically within MATCH statements (for + anything injected using fetch_relations or traverse_relations), we need a way to + retrieve them. + + """ + + relation: str + + class AsyncNodeSet(AsyncBaseSet): """ A class representing as set of nodes matching common query parameters """ - def __init__(self, source): + def __init__(self, source) -> None: self.source = source # could be a Traverse object or a node class if isinstance(source, AsyncTraversal): self.source_class = source.target_class @@ -985,9 +1034,10 @@ def __init__(self, source): self.must_match = {} self.dont_match = {} - self.relations_to_fetch: list = [] - self._extra_results: dict[str] = {} - self._subqueries: list[tuple(str, list[str])] = [] + self.relations_to_fetch: List = [] + self._extra_results: dict = {} + self._subqueries: list[Tuple[str, list[str]]] = [] + self._intermediate_transforms: list = [] def __await__(self): return self.all().__await__() @@ -1052,7 +1102,7 @@ async def first_or_none(self, **kwargs): pass return None - def filter(self, *args, **kwargs): + def filter(self, *args, **kwargs) -> "AsyncBaseSet": """ Apply filters to the existing nodes in the set. 
@@ -1287,6 +1337,15 @@ async def subquery( self._subqueries.append((qbuilder.build_query(), return_set)) return self + def intermediate_transform( + self, vars: dict[str, Any], ordering: TOptional[list] = None + ) -> "AsyncNodeSet": + for name, source in vars.items(): + if type(source) is not str and not isinstance(source, RelationNameResolver): + raise ValueError(f"Source type invalid") + self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) + return self + class AsyncTraversal(AsyncBaseSet): """ @@ -1306,7 +1365,7 @@ class AsyncTraversal(AsyncBaseSet): def __await__(self): return self.all().__await__() - def __init__(self, source, name, definition): + def __init__(self, source, name, definition) -> None: """ Create a traversal diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 1c4645c4..3954ea18 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,9 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Optional, Union +from typing import Any, List +from typing import Optional as TOptional +from typing import Tuple, Union from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -375,34 +377,34 @@ def process_has_args(cls, kwargs): class QueryAST: - match: Optional[list] - optional_match: Optional[list] - where: Optional[list] - with_clause: Optional[str] - return_clause: Optional[str] - order_by: Optional[str] - skip: Optional[int] - limit: Optional[int] - result_class: Optional[type] - lookup: Optional[str] - additional_return: Optional[list] - is_count: Optional[bool] + match: TOptional[list] + optional_match: TOptional[list] + where: TOptional[list] + with_clause: TOptional[str] + return_clause: TOptional[str] + order_by: TOptional[str] + skip: TOptional[int] + limit: TOptional[int] + result_class: TOptional[type] + lookup: TOptional[str] + additional_return: TOptional[list] + is_count: TOptional[bool] 
def __init__( self, - match: Optional[list] = None, - optional_match: Optional[list] = None, - where: Optional[list] = None, - with_clause: Optional[str] = None, - return_clause: Optional[str] = None, - order_by: Optional[str] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - result_class: Optional[type] = None, - lookup: Optional[str] = None, - additional_return: Optional[list] = None, - is_count: Optional[bool] = False, - ): + match: TOptional[list] = None, + optional_match: TOptional[list] = None, + where: TOptional[list] = None, + with_clause: TOptional[str] = None, + return_clause: TOptional[str] = None, + order_by: TOptional[str] = None, + skip: TOptional[int] = None, + limit: TOptional[int] = None, + result_class: TOptional[type] = None, + lookup: TOptional[str] = None, + additional_return: TOptional[list] = None, + is_count: TOptional[bool] = False, + ) -> None: self.match = match if match else [] self.optional_match = optional_match if optional_match else [] self.where = where if where else [] @@ -421,7 +423,7 @@ def __init__( class QueryBuilder: def __init__( self, node_set, with_subgraph: bool = False, subquery_context: bool = False - ): + ) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params = {} @@ -430,8 +432,9 @@ def __init__( self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context + self._relation_identifiers: dict[str, str] = {} - def build_ast(self): + def build_ast(self) -> "QueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): for relation in self.node_set.relations_to_fetch: self.build_traversal_from_path(relation, self.node_set.source) @@ -474,9 +477,12 @@ def build_source(self, source) -> str: return self.build_node(source) raise ValueError("Unknown source type " + repr(source)) - def create_ident(self): + def create_ident(self, relation_name: TOptional[str] = None) -> str: self._ident_count += 1 - return "r" + 
str(self._ident_count) + result = f"r{self._ident_count}" + if relation_name: + self._relation_identifiers[relation_name] = result + return result def build_order_by(self, ident, source): if "?" in source.order_by_elements: @@ -524,8 +530,12 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: parts = path.split("__") if self._with_subgraph: subgraph = self._ast.subgraph + rel_iterator: str = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) + if rel_iterator: + rel_iterator += "__" + rel_iterator += part # build source if "node_class" not in relationship.definition: relationship.lookup_node_class() @@ -559,7 +569,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt - rel_ident = self.create_ident() + rel_ident = self.create_ident(rel_iterator) if self._with_subgraph and part not in self._ast.subgraph: subgraph[part] = { "target": relationship.definition["node_class"], @@ -736,6 +746,32 @@ def build_query(self) -> str: query += " WITH " query += self._ast.with_clause + if hasattr(self.node_set, "_intermediate_transforms"): + for transform in self.node_set._intermediate_transforms: + query += " WITH " + injected_vars: list = [] + for name, source in transform["vars"].items(): + if type(source) is str: + injected_vars.append(f"{source} AS {name}") + elif isinstance(source, RelationNameResolver): + internal_name = self._relation_identifiers.get(source.relation) + if not internal_name: + raise ValueError( + f"Unable to resolve variable name for relation {source.relation}." 
+ ) + injected_vars.append(f"{internal_name} AS {name}") + query += ",".join(injected_vars) + if not transform["ordering"]: + continue + query += " ORDER BY " + ordering: list = [] + for item in transform["ordering"]: + if item.startswith("-"): + ordering.append(f"{item[1:]} DESC") + else: + ordering.append(item) + query += ",".join(ordering) + returned_items: list[str] = [] if hasattr(self.node_set, "_subqueries"): for subquery, return_set in self.node_set._subqueries: @@ -957,12 +993,25 @@ def __str__(self) -> str: return f"last({str(self.input_name)})" +@dataclass +class RelationNameResolver: + """Helper to refer to a relation variable name. + + Since variable names are generated automatically within MATCH statements (for + anything injected using fetch_relations or traverse_relations), we need a way to + retrieve them. + + """ + + relation: str + + class NodeSet(BaseSet): """ A class representing as set of nodes matching common query parameters """ - def __init__(self, source): + def __init__(self, source) -> None: self.source = source # could be a Traverse object or a node class if isinstance(source, Traversal): self.source_class = source.target_class @@ -983,9 +1032,10 @@ def __init__(self, source): self.must_match = {} self.dont_match = {} - self.relations_to_fetch: list = [] - self._extra_results: dict[str] = {} - self._subqueries: list[tuple(str, list[str])] = [] + self.relations_to_fetch: List = [] + self._extra_results: dict = {} + self._subqueries: list[Tuple[str, list[str]]] = [] + self._intermediate_transforms: list = [] def __await__(self): return self.all().__await__() @@ -1050,7 +1100,7 @@ def first_or_none(self, **kwargs): pass return None - def filter(self, *args, **kwargs): + def filter(self, *args, **kwargs) -> "BaseSet": """ Apply filters to the existing nodes in the set. 
@@ -1283,6 +1333,15 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": self._subqueries.append((qbuilder.build_query(), return_set)) return self + def intermediate_transform( + self, vars: dict[str, Any], ordering: TOptional[list] = None + ) -> "NodeSet": + for name, source in vars.items(): + if type(source) is not str and not isinstance(source, RelationNameResolver): + raise ValueError(f"Source type invalid") + self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) + return self + class Traversal(BaseSet): """ @@ -1302,7 +1361,7 @@ class Traversal(BaseSet): def __await__(self): return self.all().__await__() - def __init__(self, source, name, definition): + def __init__(self, source, name, definition) -> None: """ Create a traversal diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 52e79406..0451431f 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -26,6 +26,7 @@ Collect, Last, Optional, + RelationNameResolver, ) from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined @@ -732,14 +733,17 @@ async def test_subquery(): await nescafe.species.connect(arabica) result = await Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( - supps=Last(Collect("suppliers")) - ), + Coffee.nodes.traverse_relations(suppliers="suppliers") + .intermediate_transform( + {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + ) + .annotate(supps=Last(Collect("suppliers"))), ["supps"], ) result = await result.all() assert len(result) == 1 assert len(result[0]) == 2 + assert result[0][0] == supplier2 with raises( RuntimeError, @@ -753,6 +757,39 @@ async def test_subquery(): ) +@mark_async_test +async def test_intermediate_transform(): + # Clean DB before we start anything... 
+ await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) + await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) + await nescafe.species.connect(arabica) + + result = ( + await Coffee.nodes.traverse_relations(suppliers="suppliers") + .intermediate_transform( + { + "coffee": "coffee", + "suppliers": "suppliers", + "r": RelationNameResolver("suppliers"), + }, + ordering=["-r.since"], + ) + .annotate(oldest_supplier=Last(Collect("suppliers"))) + .all() + ) + + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][1] == supplier2 + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index a6d23523..df16f3b1 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -26,6 +26,7 @@ NodeSet, Optional, QueryBuilder, + RelationNameResolver, Traversal, ) @@ -726,14 +727,17 @@ def test_subquery(): nescafe.species.connect(arabica) result = Coffee.nodes.subquery( - Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( - supps=Last(Collect("suppliers")) - ), + Coffee.nodes.traverse_relations(suppliers="suppliers") + .intermediate_transform( + {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + ) + .annotate(supps=Last(Collect("suppliers"))), ["supps"], ) result = result.all() assert len(result) == 1 assert len(result[0]) == 2 + assert result[0][0] == supplier2 with raises( RuntimeError, @@ -747,6 +751,39 @@ def test_subquery(): ) +@mark_sync_test +def test_intermediate_transform(): + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) + nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) + nescafe.species.connect(arabica) + + result = ( + Coffee.nodes.traverse_relations(suppliers="suppliers") + .intermediate_transform( + { + "coffee": "coffee", + "suppliers": "suppliers", + "r": RelationNameResolver("suppliers"), + }, + ordering=["-r.since"], + ) + .annotate(oldest_supplier=Last(Collect("suppliers"))) + .all() + ) + + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][1] == supplier2 + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From 27d6ecc5e4fbc158b578a8e08bad4c1036859b05 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Thu, 3 Oct 2024 17:09:43 +0200 Subject: [PATCH 13/42] Fixed python 3.7 compat --- neomodel/async_/match.py | 4 ++-- neomodel/sync_/match.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 01613e5a..120d5aed 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List +from typing import Any, Dict, List from typing import Optional as TOptional from typing import Tuple, Union @@ -432,7 +432,7 @@ def __init__( self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context - self._relation_identifiers: dict[str, str] = {} + self._relation_identifiers: Dict[str, str] = {} async def build_ast(self) -> "AsyncQueryBuilder": if hasattr(self.node_set, 
"relations_to_fetch"): diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 3954ea18..36767e6e 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List +from typing import Any, Dict, List from typing import Optional as TOptional from typing import Tuple, Union @@ -432,7 +432,7 @@ def __init__( self._node_counters = defaultdict(int) self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context - self._relation_identifiers: dict[str, str] = {} + self._relation_identifiers: Dict[str, str] = {} def build_ast(self) -> "QueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): From 35f46236de2d60408a4e99afe8a917fc368633e2 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Thu, 3 Oct 2024 17:16:49 +0200 Subject: [PATCH 14/42] Added test --- neomodel/async_/match.py | 6 ++++-- neomodel/sync_/match.py | 6 ++++-- test/async_/test_match_api.py | 12 ++++++++++++ test/sync_/test_match_api.py | 12 ++++++++++++ 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 120d5aed..0d9da69a 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1338,11 +1338,13 @@ async def subquery( return self def intermediate_transform( - self, vars: dict[str, Any], ordering: TOptional[list] = None + self, vars: Dict[str, Any], ordering: TOptional[list] = None ) -> "AsyncNodeSet": for name, source in vars.items(): if type(source) is not str and not isinstance(source, RelationNameResolver): - raise ValueError(f"Source type invalid") + raise ValueError( + f"Wrong source type specified for variable '{name}', should be a string or an instance of RelationNameResolver" + ) self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) return self diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 
36767e6e..1075ef40 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1334,11 +1334,13 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": return self def intermediate_transform( - self, vars: dict[str, Any], ordering: TOptional[list] = None + self, vars: Dict[str, Any], ordering: TOptional[list] = None ) -> "NodeSet": for name, source in vars.items(): if type(source) is not str and not isinstance(source, RelationNameResolver): - raise ValueError(f"Source type invalid") + raise ValueError( + f"Wrong source type specified for variable '{name}', should be a string or an instance of RelationNameResolver" + ) self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) return self diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 0451431f..0659f468 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -789,6 +789,18 @@ async def test_intermediate_transform(): assert len(result[0]) == 2 assert result[0][1] == supplier2 + with raises( + ValueError, + match=re.escape( + r"Wrong source type specified for variable 'test', should be a string or an instance of RelationNameResolver" + ), + ): + Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( + { + "test": Collect("suppliers"), + } + ) + @mark_async_test async def test_issue_795(): diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index df16f3b1..3976c92a 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -783,6 +783,18 @@ def test_intermediate_transform(): assert len(result[0]) == 2 assert result[0][1] == supplier2 + with raises( + ValueError, + match=re.escape( + r"Wrong source type specified for variable 'test', should be a string or an instance of RelationNameResolver" + ), + ): + Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( + { + "test": Collect("suppliers"), + } + ) + @mark_sync_test def 
test_issue_795(): From 2a97f7539351d194a8f6dfaa1b1aa0677397ff9c Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Thu, 3 Oct 2024 17:36:09 +0200 Subject: [PATCH 15/42] Better type annotations --- neomodel/async_/match.py | 33 ++++++++++++++++++--------------- neomodel/sync_/match.py | 33 ++++++++++++++++++--------------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 0d9da69a..c0919aca 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -216,7 +216,7 @@ def install_traversals(cls, node_set): setattr(node_set, key, traversal) -def process_filter_args(cls, kwargs): +def process_filter_args(cls, kwargs) -> Dict: """ loop through properties in filter parameters check they match class definition deflate them and convert into something easy to generate cypher from @@ -377,8 +377,8 @@ def process_has_args(cls, kwargs): class QueryAST: - match: TOptional[list] - optional_match: TOptional[list] + match: List[str] + optional_match: List[str] where: TOptional[list] with_clause: TOptional[str] return_clause: TOptional[str] @@ -387,13 +387,13 @@ class QueryAST: limit: TOptional[int] result_class: TOptional[type] lookup: TOptional[str] - additional_return: TOptional[list] + additional_return: List[str] is_count: TOptional[bool] def __init__( self, - match: TOptional[list] = None, - optional_match: TOptional[list] = None, + match: TOptional[List[str]] = None, + optional_match: TOptional[List[str]] = None, where: TOptional[list] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, @@ -402,7 +402,7 @@ def __init__( limit: TOptional[int] = None, result_class: TOptional[type] = None, lookup: TOptional[str] = None, - additional_return: TOptional[list] = None, + additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, ) -> None: self.match = match if match else [] @@ -417,7 +417,7 @@ def __init__( self.lookup = lookup 
self.additional_return = additional_return if additional_return else [] self.is_count = is_count - self.subgraph: dict = {} + self.subgraph: Dict = {} class AsyncQueryBuilder: @@ -903,7 +903,7 @@ async def get_len(self): ast = await self.query_cls(self).build_ast() return await ast._count() - async def check_bool(self): + async def check_bool(self) -> bool: """ Override for __bool__ dunder method. :return: True if the set contains any nodes, False otherwise @@ -913,7 +913,7 @@ async def check_bool(self): _count = await ast._count() return _count > 0 - async def check_nonzero(self): + async def check_nonzero(self) -> bool: """ Override for __bool__ dunder method. :return: True if the set contains any node, False otherwise @@ -1027,12 +1027,12 @@ def __init__(self, source) -> None: # setup Traversal objects using relationship definitions install_traversals(self.source_class, self) - self.filters = [] + self.filters: List = [] self.q_filters = Q() # used by has() - self.must_match = {} - self.dont_match = {} + self.must_match: Dict = {} + self.dont_match: Dict = {} self.relations_to_fetch: List = [] self._extra_results: dict = {} @@ -1193,7 +1193,10 @@ def order_by(self, *props): return self def _register_relation_to_fetch( - self, relation_def: Any, alias: str = None, include_in_return: bool = True + self, + relation_def: Any, + alias: TOptional[str] = None, + include_in_return: bool = True, ): if isinstance(relation_def, Optional): item = {"path": relation_def.relation, "optional": True} @@ -1397,7 +1400,7 @@ def __init__(self, source, name, definition) -> None: self.definition = definition self.target_class = definition["node_class"] self.name = name - self.filters = [] + self.filters: List = [] def match(self, **kwargs): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 1075ef40..16ad5350 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -216,7 +216,7 @@ def install_traversals(cls, node_set): setattr(node_set, key, 
traversal) -def process_filter_args(cls, kwargs): +def process_filter_args(cls, kwargs) -> Dict: """ loop through properties in filter parameters check they match class definition deflate them and convert into something easy to generate cypher from @@ -377,8 +377,8 @@ def process_has_args(cls, kwargs): class QueryAST: - match: TOptional[list] - optional_match: TOptional[list] + match: List[str] + optional_match: List[str] where: TOptional[list] with_clause: TOptional[str] return_clause: TOptional[str] @@ -387,13 +387,13 @@ class QueryAST: limit: TOptional[int] result_class: TOptional[type] lookup: TOptional[str] - additional_return: TOptional[list] + additional_return: List[str] is_count: TOptional[bool] def __init__( self, - match: TOptional[list] = None, - optional_match: TOptional[list] = None, + match: TOptional[List[str]] = None, + optional_match: TOptional[List[str]] = None, where: TOptional[list] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, @@ -402,7 +402,7 @@ def __init__( limit: TOptional[int] = None, result_class: TOptional[type] = None, lookup: TOptional[str] = None, - additional_return: TOptional[list] = None, + additional_return: TOptional[List[str]] = None, is_count: TOptional[bool] = False, ) -> None: self.match = match if match else [] @@ -417,7 +417,7 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count - self.subgraph: dict = {} + self.subgraph: Dict = {} class QueryBuilder: @@ -901,7 +901,7 @@ def __len__(self): ast = self.query_cls(self).build_ast() return ast._count() - def __bool__(self): + def __bool__(self) -> bool: """ Override for __bool__ dunder method. :return: True if the set contains any nodes, False otherwise @@ -911,7 +911,7 @@ def __bool__(self): _count = ast._count() return _count > 0 - def __nonzero__(self): + def __nonzero__(self) -> bool: """ Override for __bool__ dunder method. 
:return: True if the set contains any node, False otherwise @@ -1025,12 +1025,12 @@ def __init__(self, source) -> None: # setup Traversal objects using relationship definitions install_traversals(self.source_class, self) - self.filters = [] + self.filters: List = [] self.q_filters = Q() # used by has() - self.must_match = {} - self.dont_match = {} + self.must_match: Dict = {} + self.dont_match: Dict = {} self.relations_to_fetch: List = [] self._extra_results: dict = {} @@ -1191,7 +1191,10 @@ def order_by(self, *props): return self def _register_relation_to_fetch( - self, relation_def: Any, alias: str = None, include_in_return: bool = True + self, + relation_def: Any, + alias: TOptional[str] = None, + include_in_return: bool = True, ): if isinstance(relation_def, Optional): item = {"path": relation_def.relation, "optional": True} @@ -1393,7 +1396,7 @@ def __init__(self, source, name, definition) -> None: self.definition = definition self.target_class = definition["node_class"] self.name = name - self.filters = [] + self.filters: List = [] def match(self, **kwargs): """ From 217d0c5d7e4a25f14ce064ae134cde73637a328c Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Tue, 8 Oct 2024 11:44:54 +0200 Subject: [PATCH 16/42] Added traversal support to filtering and ordering features --- neomodel/async_/match.py | 358 ++++++++++++++++++---------------- neomodel/properties.py | 13 +- neomodel/sync_/match.py | 358 ++++++++++++++++++---------------- test/async_/test_match_api.py | 50 ++++- test/sync_/test_match_api.py | 48 ++++- 5 files changed, 495 insertions(+), 332 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index c0919aca..fa7e0bf6 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -6,11 +6,12 @@ from typing import Optional as TOptional from typing import Tuple, Union +from neomodel.async_ import relationship_manager from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import 
AsyncStructuredRel from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase -from neomodel.properties import AliasProperty, ArrayProperty +from neomodel.properties import AliasProperty, ArrayProperty, Property from neomodel.util import INCOMING, OUTGOING @@ -197,6 +198,8 @@ def _rel_merge_helper( # add all regex operators OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) +path_split_regex = re.compile(r"__(?!_)") + def install_traversals(cls, node_set): """ @@ -216,136 +219,111 @@ def install_traversals(cls, node_set): setattr(node_set, key, traversal) -def process_filter_args(cls, kwargs) -> Dict: - """ - loop through properties in filter parameters check they match class definition - deflate them and convert into something easy to generate cypher from - """ - - output = {} - - for key, value in kwargs.items(): - if "__" in key: - prop, operator = key.rsplit("__") - operator = OPERATOR_TABLE[operator] - else: - prop = key - operator = "=" - - if prop not in cls.defined_properties(rels=False): +def _handle_special_operators( + property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(value, (list, tuple)): raise ValueError( - f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." 
+ f"Value must be a tuple or list for IN operation {key}={value}" ) - - property_obj = getattr(cls, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() - deflated_value = getattr(cls, prop).deflate(value) + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(value) + operator = _SPECIAL_OPERATOR_ARRAY_IN else: - operator, deflated_value = transform_operator_to_filter( - operator=operator, - filter_key=key, - filter_value=value, - property_obj=property_obj, - ) + deflated_value = [property_obj.deflate(v) for v in value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(value, bool): + raise ValueError(f"Value must be a bool for isnull operation on {key}") + operator = "IS NULL" if value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(value) - # map property to correct property name in the database - db_property = cls.defined_properties(rels=False)[prop].get_db_property_name( - prop + return deflated_value, operator, prop + + +def _deflate_value( + cls, property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + deflated_value = getattr(cls, prop).deflate(value) + else: + # handle special operators + deflated_value, operator, prop = _handle_special_operators( + property_obj, key, value, operator, prop ) - output[db_property] = (operator, deflated_value) + return deflated_value, operator, prop - return output +def 
_initialize_filter_args_variables(cls, key: str): + current_class = cls + leaf_prop = None + operator = "=" + prop = key -def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): - """ - Transform in operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" - ) - if isinstance(property_obj, ArrayProperty): - deflated_value = property_obj.deflate(filter_value) - operator = _SPECIAL_OPERATOR_ARRAY_IN - else: - deflated_value = [property_obj.deflate(v) for v in filter_value] + return current_class, leaf_prop, operator, prop - return operator, deflated_value +def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: + ( + current_class, + leaf_prop, + operator, + prop, + ) = _initialize_filter_args_variables(cls, key) -def transform_null_operator_to_filter(filter_key, filter_value): - """ - Transform null operator to a cypher filter - Args: - filter_key (str): filter key - filter_value (str): filter value - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, bool): - raise ValueError(f"Value must be a bool for isnull operation on {filter_key}") - operator = "IS NULL" if filter_value else "IS NOT NULL" - deflated_value = None - return operator, deflated_value + for part in re.split(path_split_regex, key): + defined_props = current_class.defined_properties(rels=True) + if part in defined_props: + if isinstance( + defined_props[part], relationship_manager.AsyncRelationshipDefinition + ): + defined_props[part].lookup_node_class() + current_class = defined_props[part].definition["node_class"] + elif part in OPERATOR_TABLE: + operator = 
OPERATOR_TABLE[part] + prop, _ = prop.rsplit("__", 1) + continue + else: + raise ValueError( + f"No such property {part} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + leaf_prop = part + property_obj = getattr(current_class, leaf_prop) -def transform_regex_operator_to_filter( - operator, filter_key, filter_value, property_obj -): - """ - Transform regex operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ + return property_obj, operator, prop - deflated_value = property_obj.deflate(filter_value) - if not isinstance(deflated_value, str): - raise ValueError(f"Must be a string value for {filter_key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX - return operator, deflated_value +def process_filter_args(cls, kwargs) -> Dict: + """ + loop through properties in filter parameters check they match class definition + deflate them and convert into something easy to generate cypher from + """ + output = {} -def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): - if operator == _SPECIAL_OPERATOR_IN: - operator, deflated_value = transform_in_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, - ) - elif operator == _SPECIAL_OPERATOR_ISNULL: - operator, deflated_value = transform_null_operator_to_filter( - filter_key=filter_key, filter_value=filter_value - ) - elif operator in _REGEX_OPERATOR_TABLE.values(): - operator, deflated_value = transform_regex_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, + for key, value 
in kwargs.items(): + property_obj, operator, prop = _process_filter_key(cls, key) + deflated_value, operator, prop = _deflate_value( + cls, property_obj, key, value, operator, prop ) - else: - deflated_value = property_obj.deflate(filter_value) - return operator, deflated_value + # map property to correct property name in the database + db_property = prop + + output[db_property] = (operator, deflated_value) + return output def process_has_args(cls, kwargs): @@ -382,7 +360,7 @@ class QueryAST: where: TOptional[list] with_clause: TOptional[str] return_clause: TOptional[str] - order_by: TOptional[str] + order_by: TOptional[List[str]] skip: TOptional[int] limit: TOptional[int] result_class: TOptional[type] @@ -397,7 +375,7 @@ def __init__( where: TOptional[list] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, - order_by: TOptional[str] = None, + order_by: TOptional[List[str]] = None, skip: TOptional[int] = None, limit: TOptional[int] = None, result_class: TOptional[type] = None, @@ -421,18 +399,14 @@ def __init__( class AsyncQueryBuilder: - def __init__( - self, node_set, with_subgraph: bool = False, subquery_context: bool = False - ) -> None: + def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() - self._query_params = {} + self._query_params: Dict = {} self._place_holder_registry = {} - self._ident_count = 0 + self._ident_count: int = 0 self._node_counters = defaultdict(int) - self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context - self._relation_identifiers: Dict[str, str] = {} async def build_ast(self) -> "AsyncQueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -477,19 +451,31 @@ async def build_source(self, source) -> str: return await self.build_node(source) raise ValueError("Unknown source type " + repr(source)) - def create_ident(self, relation_name: TOptional[str] = None) -> str: + def create_ident(self) -> str: 
self._ident_count += 1 - result = f"r{self._ident_count}" - if relation_name: - self._relation_identifiers[relation_name] = result - return result + return f"r{self._ident_count}" - def build_order_by(self, ident, source): + def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: if "?" in source.order_by_elements: self._ast.with_clause = f"{ident}, rand() as r" - self._ast.order_by = "r" + self._ast.order_by = ["r"] else: - self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] + order_by = [] + for elm in source.order_by_elements: + if "__" not in elm: + prop = elm.split(" ")[0] if " " in elm else elm + if prop not in source.source_class.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {source.source_class.__name__}. " + f"Note that Neo4j internals like id or element_id are not allowed " + f"for use in this operation." + ) + order_by.append(f"{ident}.{elm}") + else: + path, prop = elm.rsplit("__", 1) + order_by_clause = self.lookup_query_variable(path) + order_by.append(f"{order_by_clause}.{prop}") + self._ast.order_by = order_by async def build_traversal(self, traversal): """ @@ -527,9 +513,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - parts = path.split("__") - if self._with_subgraph: - subgraph = self._ast.subgraph + parts = re.split(path_split_regex, path) + subgraph = self._ast.subgraph rel_iterator: str = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) @@ -569,8 +554,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt - rel_ident = self.create_ident(rel_iterator) - if self._with_subgraph and part not in self._ast.subgraph: + rel_ident = self.create_ident() + if part not in self._ast.subgraph: subgraph[part] = { "target": relationship.definition["node_class"], "children": {}, @@ -587,8 
+572,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] - if self._with_subgraph: - subgraph = subgraph[part]["children"] + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -653,6 +637,41 @@ def _register_place_holder(self, key): self._place_holder_registry[key] = 1 return key + "_" + str(self._place_holder_registry[key]) + def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: + path, prop = prop.rsplit("__", 1) + ident = self.build_traversal_from_path( + {"path": path, "include_in_return": True}, + source_class, + ) + return ident, path, prop + + def _finalize_filter_statement(self, operator, ident, prop, val) -> str: + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = f"{ident}.{prop} {operator}" + else: + place_holder = self._register_place_holder(ident + "_" + prop) + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + else: + statement = f"{ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + + return statement + + def _build_filter_statements(self, ident, filters, target, source_class): + for prop, op_and_val in filters.items(): + path = None + if "__" in prop: + ident, path, prop = self._parse_path(source_class, prop) + operator, val = op_and_val + statement = self._finalize_filter_statement(operator, ident, prop, val) + target.append(statement) + def _parse_q_filters(self, ident, q, source_class): target = [] for child in q.children: @@ -664,23 +683,7 @@ def _parse_q_filters(self, ident, q, source_class): else: kwargs = {child[0]: child[1]} filters = process_filter_args(source_class, kwargs) - for prop, op_and_val in filters.items(): - operator, val = op_and_val - if operator in 
_UNARY_OPERATORS: - # unary operators do not have a parameter - statement = f"{ident}.{prop} {operator}" - else: - place_holder = self._register_place_holder(ident + "_" + prop) - if operator == _SPECIAL_OPERATOR_ARRAY_IN: - statement = operator.format( - ident=ident, - prop=prop, - val=f"${place_holder}", - ) - else: - statement = f"{ident}.{prop} {operator} ${place_holder}" - self._query_params[place_holder] = val - target.append(statement) + self._build_filter_statements(ident, filters, target, source_class) ret = f" {q.connector} ".join(target) if q.negated: ret = f"NOT ({ret})" @@ -719,6 +722,35 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) + def lookup_query_variable( + self, path: str, return_relation: bool = False + ) -> TOptional[str]: + """Retrieve the variable name generated internally for the given traversal path.""" + subgraph = self._ast.subgraph + if not subgraph: + return None + traversals = re.split(path_split_regex, path) + if len(traversals) == 0: + raise ValueError("Can only lookup traversal variables") + if traversals[0] not in subgraph: + return None + subgraph = subgraph[traversals[0]] + variable_to_return = None + last_property = traversals[-1] + for part in traversals: + if part in subgraph["children"]: + subgraph = subgraph["children"][part] + elif part == last_property: + # if last part of prop is the last traversal + # we are safe to lookup the variable from the query + if return_relation: + variable_to_return = f"{subgraph['rel_variable_name']}" + else: + variable_to_return = f"{subgraph['variable_name']}" + else: + break + return variable_to_return + def build_query(self) -> str: query: str = "" @@ -754,7 +786,9 @@ def build_query(self) -> str: if type(source) is str: injected_vars.append(f"{source} AS {name}") elif isinstance(source, RelationNameResolver): - internal_name = self._relation_identifiers.get(source.relation) + internal_name = 
self.lookup_query_variable( + source.relation, return_relation=True + ) if not internal_name: raise ValueError( f"Unable to resolve variable name for relation {source.relation}." @@ -880,6 +914,7 @@ class AsyncBaseSet: """ query_cls = AsyncQueryBuilder + source_class: AsyncStructuredNode async def all(self, lazy=False): """ @@ -1029,6 +1064,7 @@ def __init__(self, source) -> None: self.filters: List = [] self.q_filters = Q() + self.order_by_elements: List = [] # used by has() self.must_match: Dict = {} @@ -1179,14 +1215,10 @@ def order_by(self, *props): else: desc = False - if prop not in self.source_class.defined_properties(rels=False): - raise ValueError( - f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." - ) - - property_obj = getattr(self.source_class, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() + if prop in self.source_class.defined_properties(rels=False): + property_obj = getattr(self.source_class, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() self.order_by_elements.append(prop + (" DESC" if desc else "")) @@ -1300,7 +1332,7 @@ async def resolve_subgraph(self) -> list: "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
) results: list = [] - qbuilder = self.query_cls(self, with_subgraph=True) + qbuilder = self.query_cls(self) await qbuilder.build_ast() all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} diff --git a/neomodel/properties.py b/neomodel/properties.py index d4a91885..8c848ea8 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -1,9 +1,10 @@ import functools import json import re -import sys import uuid +from abc import ABCMeta, abstractmethod from datetime import date, datetime +from typing import Any import neo4j.time import pytz @@ -37,7 +38,7 @@ def _validator(self, value, obj=None, rethrow=True): return _validator -class FulltextIndex(object): +class FulltextIndex: """ Fulltext index definition """ @@ -57,7 +58,7 @@ def __init__( self.eventually_consistent = eventually_consistent -class VectorIndex(object): +class VectorIndex: """ Vector index definition """ @@ -73,7 +74,7 @@ def __init__(self, dimensions=1536, similarity_function="cosine"): self.similarity_function = similarity_function -class Property: +class Property(metaclass=ABCMeta): """ Base class for object properties. 
@@ -158,6 +159,10 @@ def get_db_property_name(self, attribute_name): def is_indexed(self): return self.unique_index or self.index + @abstractmethod + def deflate(self, value: Any) -> Any: + pass + class NormalizedProperty(Property): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 16ad5350..a1d311fd 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -8,7 +8,8 @@ from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase -from neomodel.properties import AliasProperty, ArrayProperty +from neomodel.properties import AliasProperty, ArrayProperty, Property +from neomodel.sync_ import relationship_manager from neomodel.sync_.core import StructuredNode, db from neomodel.sync_.relationship import StructuredRel from neomodel.util import INCOMING, OUTGOING @@ -197,6 +198,8 @@ def _rel_merge_helper( # add all regex operators OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) +path_split_regex = re.compile(r"__(?!_)") + def install_traversals(cls, node_set): """ @@ -216,136 +219,111 @@ def install_traversals(cls, node_set): setattr(node_set, key, traversal) -def process_filter_args(cls, kwargs) -> Dict: - """ - loop through properties in filter parameters check they match class definition - deflate them and convert into something easy to generate cypher from - """ - - output = {} - - for key, value in kwargs.items(): - if "__" in key: - prop, operator = key.rsplit("__") - operator = OPERATOR_TABLE[operator] - else: - prop = key - operator = "=" - - if prop not in cls.defined_properties(rels=False): +def _handle_special_operators( + property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(value, (list, tuple)): raise ValueError( - f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." 
+ f"Value must be a tuple or list for IN operation {key}={value}" ) - - property_obj = getattr(cls, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() - deflated_value = getattr(cls, prop).deflate(value) + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(value) + operator = _SPECIAL_OPERATOR_ARRAY_IN else: - operator, deflated_value = transform_operator_to_filter( - operator=operator, - filter_key=key, - filter_value=value, - property_obj=property_obj, - ) + deflated_value = [property_obj.deflate(v) for v in value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(value, bool): + raise ValueError(f"Value must be a bool for isnull operation on {key}") + operator = "IS NULL" if value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(value) - # map property to correct property name in the database - db_property = cls.defined_properties(rels=False)[prop].get_db_property_name( - prop + return deflated_value, operator, prop + + +def _deflate_value( + cls, property_obj: Property, key: str, value: str, operator: str, prop: str +) -> Tuple[str, str, str]: + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + deflated_value = getattr(cls, prop).deflate(value) + else: + # handle special operators + deflated_value, operator, prop = _handle_special_operators( + property_obj, key, value, operator, prop ) - output[db_property] = (operator, deflated_value) + return deflated_value, operator, prop - return output +def 
_initialize_filter_args_variables(cls, key: str): + current_class = cls + leaf_prop = None + operator = "=" + prop = key -def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): - """ - Transform in operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" - ) - if isinstance(property_obj, ArrayProperty): - deflated_value = property_obj.deflate(filter_value) - operator = _SPECIAL_OPERATOR_ARRAY_IN - else: - deflated_value = [property_obj.deflate(v) for v in filter_value] + return current_class, leaf_prop, operator, prop - return operator, deflated_value +def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: + ( + current_class, + leaf_prop, + operator, + prop, + ) = _initialize_filter_args_variables(cls, key) -def transform_null_operator_to_filter(filter_key, filter_value): - """ - Transform null operator to a cypher filter - Args: - filter_key (str): filter key - filter_value (str): filter value - Returns: - tuple: operator, deflated_value - """ - if not isinstance(filter_value, bool): - raise ValueError(f"Value must be a bool for isnull operation on {filter_key}") - operator = "IS NULL" if filter_value else "IS NOT NULL" - deflated_value = None - return operator, deflated_value + for part in re.split(path_split_regex, key): + defined_props = current_class.defined_properties(rels=True) + if part in defined_props: + if isinstance( + defined_props[part], relationship_manager.RelationshipDefinition + ): + defined_props[part].lookup_node_class() + current_class = defined_props[part].definition["node_class"] + elif part in OPERATOR_TABLE: + operator = OPERATOR_TABLE[part] + 
prop, _ = prop.rsplit("__", 1) + continue + else: + raise ValueError( + f"No such property {part} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + leaf_prop = part + property_obj = getattr(current_class, leaf_prop) -def transform_regex_operator_to_filter( - operator, filter_key, filter_value, property_obj -): - """ - Transform regex operator to a cypher filter - Args: - operator (str): operator to transform - filter_key (str): filter key - filter_value (str): filter value - property_obj (object): property object - Returns: - tuple: operator, deflated_value - """ + return property_obj, operator, prop - deflated_value = property_obj.deflate(filter_value) - if not isinstance(deflated_value, str): - raise ValueError(f"Must be a string value for {filter_key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX - return operator, deflated_value +def process_filter_args(cls, kwargs) -> Dict: + """ + loop through properties in filter parameters check they match class definition + deflate them and convert into something easy to generate cypher from + """ + output = {} -def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): - if operator == _SPECIAL_OPERATOR_IN: - operator, deflated_value = transform_in_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, - ) - elif operator == _SPECIAL_OPERATOR_ISNULL: - operator, deflated_value = transform_null_operator_to_filter( - filter_key=filter_key, filter_value=filter_value - ) - elif operator in _REGEX_OPERATOR_TABLE.values(): - operator, deflated_value = transform_regex_operator_to_filter( - operator=operator, - filter_key=filter_key, - filter_value=filter_value, - property_obj=property_obj, + for key, value in kwargs.items(): + 
property_obj, operator, prop = _process_filter_key(cls, key) + deflated_value, operator, prop = _deflate_value( + cls, property_obj, key, value, operator, prop ) - else: - deflated_value = property_obj.deflate(filter_value) - return operator, deflated_value + # map property to correct property name in the database + db_property = prop + + output[db_property] = (operator, deflated_value) + return output def process_has_args(cls, kwargs): @@ -382,7 +360,7 @@ class QueryAST: where: TOptional[list] with_clause: TOptional[str] return_clause: TOptional[str] - order_by: TOptional[str] + order_by: TOptional[List[str]] skip: TOptional[int] limit: TOptional[int] result_class: TOptional[type] @@ -397,7 +375,7 @@ def __init__( where: TOptional[list] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, - order_by: TOptional[str] = None, + order_by: TOptional[List[str]] = None, skip: TOptional[int] = None, limit: TOptional[int] = None, result_class: TOptional[type] = None, @@ -421,18 +399,14 @@ def __init__( class QueryBuilder: - def __init__( - self, node_set, with_subgraph: bool = False, subquery_context: bool = False - ) -> None: + def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() - self._query_params = {} + self._query_params: Dict = {} self._place_holder_registry = {} - self._ident_count = 0 + self._ident_count: int = 0 self._node_counters = defaultdict(int) - self._with_subgraph: bool = with_subgraph self._subquery_context: bool = subquery_context - self._relation_identifiers: Dict[str, str] = {} def build_ast(self) -> "QueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -477,19 +451,31 @@ def build_source(self, source) -> str: return self.build_node(source) raise ValueError("Unknown source type " + repr(source)) - def create_ident(self, relation_name: TOptional[str] = None) -> str: + def create_ident(self) -> str: self._ident_count += 1 - result = 
f"r{self._ident_count}" - if relation_name: - self._relation_identifiers[relation_name] = result - return result + return f"r{self._ident_count}" - def build_order_by(self, ident, source): + def build_order_by(self, ident: str, source: "NodeSet") -> None: if "?" in source.order_by_elements: self._ast.with_clause = f"{ident}, rand() as r" - self._ast.order_by = "r" + self._ast.order_by = ["r"] else: - self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] + order_by = [] + for elm in source.order_by_elements: + if "__" not in elm: + prop = elm.split(" ")[0] if " " in elm else elm + if prop not in source.source_class.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {source.source_class.__name__}. " + f"Note that Neo4j internals like id or element_id are not allowed " + f"for use in this operation." + ) + order_by.append(f"{ident}.{elm}") + else: + path, prop = elm.rsplit("__", 1) + order_by_clause = self.lookup_query_variable(path) + order_by.append(f"{order_by_clause}.{prop}") + self._ast.order_by = order_by def build_traversal(self, traversal): """ @@ -527,9 +513,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - parts = path.split("__") - if self._with_subgraph: - subgraph = self._ast.subgraph + parts = re.split(path_split_regex, path) + subgraph = self._ast.subgraph rel_iterator: str = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) @@ -569,8 +554,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt - rel_ident = self.create_ident(rel_iterator) - if self._with_subgraph and part not in self._ast.subgraph: + rel_ident = self.create_ident() + if part not in self._ast.subgraph: subgraph[part] = { "target": relationship.definition["node_class"], "children": {}, @@ -587,8 +572,7 @@ def build_traversal_from_path(self, 
relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] - if self._with_subgraph: - subgraph = subgraph[part]["children"] + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -653,6 +637,41 @@ def _register_place_holder(self, key): self._place_holder_registry[key] = 1 return key + "_" + str(self._place_holder_registry[key]) + def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: + path, prop = prop.rsplit("__", 1) + ident = self.build_traversal_from_path( + {"path": path, "include_in_return": True}, + source_class, + ) + return ident, path, prop + + def _finalize_filter_statement(self, operator, ident, prop, val) -> str: + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = f"{ident}.{prop} {operator}" + else: + place_holder = self._register_place_holder(ident + "_" + prop) + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + else: + statement = f"{ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + + return statement + + def _build_filter_statements(self, ident, filters, target, source_class): + for prop, op_and_val in filters.items(): + path = None + if "__" in prop: + ident, path, prop = self._parse_path(source_class, prop) + operator, val = op_and_val + statement = self._finalize_filter_statement(operator, ident, prop, val) + target.append(statement) + def _parse_q_filters(self, ident, q, source_class): target = [] for child in q.children: @@ -664,23 +683,7 @@ def _parse_q_filters(self, ident, q, source_class): else: kwargs = {child[0]: child[1]} filters = process_filter_args(source_class, kwargs) - for prop, op_and_val in filters.items(): - operator, val = op_and_val - if operator in _UNARY_OPERATORS: - # unary operators do not have a 
parameter - statement = f"{ident}.{prop} {operator}" - else: - place_holder = self._register_place_holder(ident + "_" + prop) - if operator == _SPECIAL_OPERATOR_ARRAY_IN: - statement = operator.format( - ident=ident, - prop=prop, - val=f"${place_holder}", - ) - else: - statement = f"{ident}.{prop} {operator} ${place_holder}" - self._query_params[place_holder] = val - target.append(statement) + self._build_filter_statements(ident, filters, target, source_class) ret = f" {q.connector} ".join(target) if q.negated: ret = f"NOT ({ret})" @@ -719,6 +722,35 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) + def lookup_query_variable( + self, path: str, return_relation: bool = False + ) -> TOptional[str]: + """Retrieve the variable name generated internally for the given traversal path.""" + subgraph = self._ast.subgraph + if not subgraph: + return None + traversals = re.split(path_split_regex, path) + if len(traversals) == 0: + raise ValueError("Can only lookup traversal variables") + if traversals[0] not in subgraph: + return None + subgraph = subgraph[traversals[0]] + variable_to_return = None + last_property = traversals[-1] + for part in traversals: + if part in subgraph["children"]: + subgraph = subgraph["children"][part] + elif part == last_property: + # if last part of prop is the last traversal + # we are safe to lookup the variable from the query + if return_relation: + variable_to_return = f"{subgraph['rel_variable_name']}" + else: + variable_to_return = f"{subgraph['variable_name']}" + else: + break + return variable_to_return + def build_query(self) -> str: query: str = "" @@ -754,7 +786,9 @@ def build_query(self) -> str: if type(source) is str: injected_vars.append(f"{source} AS {name}") elif isinstance(source, RelationNameResolver): - internal_name = self._relation_identifiers.get(source.relation) + internal_name = self.lookup_query_variable( + source.relation, return_relation=True + 
) if not internal_name: raise ValueError( f"Unable to resolve variable name for relation {source.relation}." @@ -878,6 +912,7 @@ class BaseSet: """ query_cls = QueryBuilder + source_class: StructuredNode def all(self, lazy=False): """ @@ -1027,6 +1062,7 @@ def __init__(self, source) -> None: self.filters: List = [] self.q_filters = Q() + self.order_by_elements: List = [] # used by has() self.must_match: Dict = {} @@ -1177,14 +1213,10 @@ def order_by(self, *props): else: desc = False - if prop not in self.source_class.defined_properties(rels=False): - raise ValueError( - f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." - ) - - property_obj = getattr(self.source_class, prop) - if isinstance(property_obj, AliasProperty): - prop = property_obj.aliased_to() + if prop in self.source_class.defined_properties(rels=False): + property_obj = getattr(self.source_class, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() self.order_by_elements.append(prop + (" DESC" if desc else "")) @@ -1298,7 +1330,7 @@ def resolve_subgraph(self) -> list: "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
) results: list = [] - qbuilder = self.query_cls(self, with_subgraph=True) + qbuilder = self.query_cls(self) qbuilder.build_ast() all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 0659f468..1d90b625 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -310,7 +310,7 @@ async def test_order_by(): ns = ns.order_by("?") qb = await AsyncQueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" - assert qb._ast.order_by == "r" + assert qb._ast.order_by == ["r"] with raises( ValueError, @@ -544,8 +544,32 @@ async def test_traversal_filter_left_hand_statement(): assert lidl in lidl_supplier +@mark_async_test +async def test_filter_with_traversal(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe", price=11).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=99).save() + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + + results = await Coffee.nodes.filter(species__name="Arabica").all() + assert len(results) == 1 + assert len(results[0]) == 3 + assert results[0][0] == nescafe + + @mark_async_test async def test_fetch_relations(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe 1000", price=99).save() @@ -601,6 +625,30 @@ async def test_fetch_relations(): ) +@mark_async_test +async def test_traverse_and_order_by(): + # Clean DB before we start anything... 
+ await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=110).save() + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + + results = ( + await Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all() + ) + assert len(results) == 2 + assert len(results[0]) == 3 # 2 nodes and 1 relation + assert results[0][0] == robusta + assert results[1][0] == arabica + + @mark_async_test async def test_annotate_and_collect(): # Clean DB before we start anything... diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 3976c92a..f9beafa6 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -306,7 +306,7 @@ def test_order_by(): ns = ns.order_by("?") qb = QueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" - assert qb._ast.order_by == "r" + assert qb._ast.order_by == ["r"] with raises( ValueError, @@ -538,8 +538,32 @@ def test_traversal_filter_left_hand_statement(): assert lidl in lidl_supplier +@mark_sync_test +def test_filter_with_traversal(): + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe", price=11).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=99).save() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + + results = Coffee.nodes.filter(species__name="Arabica").all() + assert len(results) == 1 + assert len(results[0]) == 3 + assert results[0][0] == nescafe + + @mark_sync_test def test_fetch_relations(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe 1000", price=99).save() @@ -595,6 +619,28 @@ def test_fetch_relations(): ) +@mark_sync_test +def test_traverse_and_order_by(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=110).save() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + + results = Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all() + assert len(results) == 2 + assert len(results[0]) == 3 # 2 nodes and 1 relation + assert results[0][0] == robusta + assert results[1][0] == arabica + + @mark_sync_test def test_annotate_and_collect(): # Clean DB before we start anything... 
From 272393fac6363dfdf3eff194598a9ca1c6958ad7 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Tue, 8 Oct 2024 15:31:30 +0200 Subject: [PATCH 17/42] Added missing db property resolution call --- neomodel/async_/match.py | 34 +++++++++++++++++++++------------- neomodel/sync_/match.py | 34 +++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index fa7e0bf6..6bd81ff6 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -318,7 +318,6 @@ def process_filter_args(cls, kwargs) -> Dict: deflated_value, operator, prop = _deflate_value( cls, property_obj, key, value, operator, prop ) - # map property to correct property name in the database db_property = prop @@ -357,7 +356,7 @@ def process_has_args(cls, kwargs): class QueryAST: match: List[str] optional_match: List[str] - where: TOptional[list] + where: List[str] with_clause: TOptional[str] return_clause: TOptional[str] order_by: TOptional[List[str]] @@ -372,7 +371,7 @@ def __init__( self, match: TOptional[List[str]] = None, optional_match: TOptional[List[str]] = None, - where: TOptional[list] = None, + where: TOptional[List[str]] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, order_by: TOptional[List[str]] = None, @@ -403,7 +402,7 @@ def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: Dict = {} - self._place_holder_registry = {} + self._place_holder_registry: Dict = {} self._ident_count: int = 0 self._node_counters = defaultdict(int) self._subquery_context: bool = subquery_context @@ -477,7 +476,7 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: order_by.append(f"{order_by_clause}.{prop}") self._ast.order_by = order_by - async def build_traversal(self, traversal): + async def build_traversal(self, traversal) -> str: """ traverse a relationship from a node to a set 
of nodes """ @@ -630,7 +629,7 @@ def build_additional_match(self, ident, node_set): else: raise ValueError("Expecting dict got: " + repr(val)) - def _register_place_holder(self, key): + def _register_place_holder(self, key: str) -> str: if key in self._place_holder_registry: self._place_holder_registry[key] += 1 else: @@ -645,7 +644,9 @@ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: ) return ident, path, prop - def _finalize_filter_statement(self, operator, ident, prop, val) -> str: + def _finalize_filter_statement( + self, operator: str, ident: str, prop: str, val: Any + ) -> str: if operator in _UNARY_OPERATORS: # unary operators do not have a parameter statement = f"{ident}.{prop} {operator}" @@ -663,16 +664,21 @@ def _finalize_filter_statement(self, operator, ident, prop, val) -> str: return statement - def _build_filter_statements(self, ident, filters, target, source_class): + def _build_filter_statements( + self, ident: str, filters, target: List[str], source_class + ) -> None: for prop, op_and_val in filters.items(): path = None if "__" in prop: ident, path, prop = self._parse_path(source_class, prop) operator, val = op_and_val + prop = source_class.defined_properties(rels=False)[ + prop + ].get_db_property_name(prop) statement = self._finalize_filter_statement(operator, ident, prop, val) target.append(statement) - def _parse_q_filters(self, ident, q, source_class): + def _parse_q_filters(self, ident, q, source_class) -> str: target = [] for child in q.children: if isinstance(child, QBase): @@ -689,14 +695,16 @@ def _parse_q_filters(self, ident, q, source_class): ret = f"NOT ({ret})" return ret - def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): + def build_where_stmt( + self, ident: str, filters, q_filters=None, source_class=None + ) -> None: """ construct a where statement from some filters """ if q_filters is not None: - stmts = self._parse_q_filters(ident, q_filters, source_class) - if stmts: - 
self._ast.where.append(stmts) + stmt = self._parse_q_filters(ident, q_filters, source_class) + if stmt: + self._ast.where.append(stmt) else: stmts = [] for row in filters: diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index a1d311fd..3a39d7ae 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -318,7 +318,6 @@ def process_filter_args(cls, kwargs) -> Dict: deflated_value, operator, prop = _deflate_value( cls, property_obj, key, value, operator, prop ) - # map property to correct property name in the database db_property = prop @@ -357,7 +356,7 @@ def process_has_args(cls, kwargs): class QueryAST: match: List[str] optional_match: List[str] - where: TOptional[list] + where: List[str] with_clause: TOptional[str] return_clause: TOptional[str] order_by: TOptional[List[str]] @@ -372,7 +371,7 @@ def __init__( self, match: TOptional[List[str]] = None, optional_match: TOptional[List[str]] = None, - where: TOptional[list] = None, + where: TOptional[List[str]] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, order_by: TOptional[List[str]] = None, @@ -403,7 +402,7 @@ def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: Dict = {} - self._place_holder_registry = {} + self._place_holder_registry: Dict = {} self._ident_count: int = 0 self._node_counters = defaultdict(int) self._subquery_context: bool = subquery_context @@ -477,7 +476,7 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None: order_by.append(f"{order_by_clause}.{prop}") self._ast.order_by = order_by - def build_traversal(self, traversal): + def build_traversal(self, traversal) -> str: """ traverse a relationship from a node to a set of nodes """ @@ -630,7 +629,7 @@ def build_additional_match(self, ident, node_set): else: raise ValueError("Expecting dict got: " + repr(val)) - def _register_place_holder(self, key): + def _register_place_holder(self, key: 
str) -> str: if key in self._place_holder_registry: self._place_holder_registry[key] += 1 else: @@ -645,7 +644,9 @@ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: ) return ident, path, prop - def _finalize_filter_statement(self, operator, ident, prop, val) -> str: + def _finalize_filter_statement( + self, operator: str, ident: str, prop: str, val: Any + ) -> str: if operator in _UNARY_OPERATORS: # unary operators do not have a parameter statement = f"{ident}.{prop} {operator}" @@ -663,16 +664,21 @@ def _finalize_filter_statement(self, operator, ident, prop, val) -> str: return statement - def _build_filter_statements(self, ident, filters, target, source_class): + def _build_filter_statements( + self, ident: str, filters, target: List[str], source_class + ) -> None: for prop, op_and_val in filters.items(): path = None if "__" in prop: ident, path, prop = self._parse_path(source_class, prop) operator, val = op_and_val + prop = source_class.defined_properties(rels=False)[ + prop + ].get_db_property_name(prop) statement = self._finalize_filter_statement(operator, ident, prop, val) target.append(statement) - def _parse_q_filters(self, ident, q, source_class): + def _parse_q_filters(self, ident, q, source_class) -> str: target = [] for child in q.children: if isinstance(child, QBase): @@ -689,14 +695,16 @@ def _parse_q_filters(self, ident, q, source_class): ret = f"NOT ({ret})" return ret - def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): + def build_where_stmt( + self, ident: str, filters, q_filters=None, source_class=None + ) -> None: """ construct a where statement from some filters """ if q_filters is not None: - stmts = self._parse_q_filters(ident, q_filters, source_class) - if stmts: - self._ast.where.append(stmts) + stmt = self._parse_q_filters(ident, q_filters, source_class) + if stmt: + self._ast.where.append(stmt) else: stmts = [] for row in filters: From 57c18a56238a5ec263fefac797734240877378c3 Mon Sep 17 
00:00:00 2001 From: Antoine Nguyen Date: Tue, 8 Oct 2024 15:51:14 +0200 Subject: [PATCH 18/42] Fixed unit test --- test/async_/test_match_api.py | 6 +++--- test/sync_/test_match_api.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 1d90b625..09b3a3e8 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -287,8 +287,8 @@ async def test_contains(): @mark_async_test async def test_order_by(): - for c in await Coffee.nodes: - await c.delete() + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") c1 = await Coffee(name="Icelands finest", price=5).save() c2 = await Coffee(name="Britains finest", price=10).save() @@ -316,7 +316,7 @@ async def test_order_by(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - await Coffee.nodes.order_by("id") + await Coffee.nodes.order_by("id").all() # Test order by on a relationship l = await Supplier(name="lidl2").save() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index f9beafa6..942b8482 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -283,8 +283,8 @@ def test_contains(): @mark_sync_test def test_order_by(): - for c in Coffee.nodes: - c.delete() + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") c1 = Coffee(name="Icelands finest", price=5).save() c2 = Coffee(name="Britains finest", price=10).save() @@ -312,7 +312,7 @@ def test_order_by(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - Coffee.nodes.order_by("id") + Coffee.nodes.order_by("id").all() # Test order by on a relationship l = Supplier(name="lidl2").save() From 6cd0b929c16297e9063655f5765e32943bafe0e8 Mon Sep 17 00:00:00 2001 From: Daniyar Irishev Date: Mon, 14 Oct 2024 08:20:37 +0900 Subject: [PATCH 19/42] feat: add ensure_ascii kwarg to JSONProperty --- neomodel/properties.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/neomodel/properties.py b/neomodel/properties.py index d4a91885..52a3a243 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -529,8 +529,9 @@ class JSONProperty(Property): The structure will be inflated when a node is retrieved. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, ensure_ascii=False, *args, **kwargs): + self.ensure_ascii = ensure_ascii + super(JSONProperty, self).__init__(*args, **kwargs) @validator def inflate(self, value): @@ -538,7 +539,7 @@ def inflate(self, value): @validator def deflate(self, value): - return json.dumps(value) + return json.dumps(value, ensure_ascii=self.ensure_ascii) class AliasProperty(property, Property): From 62bdf0a0d058bcfe7f147ed3cda2f5e6eaab2850 Mon Sep 17 00:00:00 2001 From: Daniyar Irishev Date: Mon, 14 Oct 2024 08:25:57 +0900 Subject: [PATCH 20/42] fix: set default ensure_ascii to True --- neomodel/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neomodel/properties.py b/neomodel/properties.py index 52a3a243..352a8899 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -529,7 +529,7 @@ class JSONProperty(Property): The structure will be inflated when a node is retrieved. 
""" - def __init__(self, ensure_ascii=False, *args, **kwargs): + def __init__(self, ensure_ascii=True, *args, **kwargs): self.ensure_ascii = ensure_ascii super(JSONProperty, self).__init__(*args, **kwargs) From 8f34323250d023652199820078627e75f17d1f95 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 14 Oct 2024 13:29:35 +0200 Subject: [PATCH 21/42] Allow raw cypher statement as argument of order_by() --- neomodel/async_/match.py | 30 ++++++++++++++++++++++++++++++ neomodel/sync_/match.py | 30 ++++++++++++++++++++++++++++++ test/async_/test_match_api.py | 29 +++++++++++++++++++++++++++++ test/sync_/test_match_api.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 6bd81ff6..ecb5263f 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1,5 +1,6 @@ import inspect import re +import string from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List @@ -14,6 +15,8 @@ from neomodel.properties import AliasProperty, ArrayProperty, Property from neomodel.util import INCOMING, OUTGOING +CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") + def _rel_helper( lhs, @@ -461,6 +464,9 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: else: order_by = [] for elm in source.order_by_elements: + if isinstance(elm, RawCypher): + order_by.append(elm.render({"n": ident})) + continue if "__" not in elm: prop = elm.split(" ")[0] if " " in elm else elm if prop not in source.source_class.defined_properties(rels=False): @@ -895,6 +901,7 @@ async def _execute(self, lazy: bool = False, dict_output: bool = False): for item in self._ast.additional_return ] query = self.build_query() + print(query) results, prop_names = await adb.cypher_query( query, self._query_params, resolve_objects=True ) @@ -1051,6 +1058,26 @@ class RelationNameResolver: relation: str +@dataclass +class 
RawCypher: + """Helper to inject raw cypher statement. + + Can be used in order_by() call for example. + + """ + + statement: str + + def __post_init__(self): + if CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR.search(self.statement): + raise ValueError( + "RawCypher: Do not include any action that has side effect" + ) + + def render(self, context: Dict) -> str: + return string.Template(self.statement).substitute(context) + + class AsyncNodeSet(AsyncBaseSet): """ A class representing as set of nodes matching common query parameters @@ -1216,6 +1243,9 @@ def order_by(self, *props): self.order_by_elements.append("?") else: for prop in props: + if isinstance(prop, RawCypher): + self.order_by_elements.append(prop) + continue prop = prop.strip() if prop.startswith("-"): prop = prop[1:] diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 3a39d7ae..64696959 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,5 +1,6 @@ import inspect import re +import string from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List @@ -14,6 +15,8 @@ from neomodel.sync_.relationship import StructuredRel from neomodel.util import INCOMING, OUTGOING +CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") + def _rel_helper( lhs, @@ -461,6 +464,9 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None: else: order_by = [] for elm in source.order_by_elements: + if isinstance(elm, RawCypher): + order_by.append(elm.render({"n": ident})) + continue if "__" not in elm: prop = elm.split(" ")[0] if " " in elm else elm if prop not in source.source_class.defined_properties(rels=False): @@ -893,6 +899,7 @@ def _execute(self, lazy: bool = False, dict_output: bool = False): for item in self._ast.additional_return ] query = self.build_query() + print(query) results, prop_names = db.cypher_query( query, self._query_params, resolve_objects=True ) @@ -1049,6 +1056,26 @@ class 
RelationNameResolver: relation: str +@dataclass +class RawCypher: + """Helper to inject raw cypher statement. + + Can be used in order_by() call for example. + + """ + + statement: str + + def __post_init__(self): + if CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR.search(self.statement): + raise ValueError( + "RawCypher: Do not include any action that has side effect" + ) + + def render(self, context: Dict) -> str: + return string.Template(self.statement).substitute(context) + + class NodeSet(BaseSet): """ A class representing as set of nodes matching common query parameters @@ -1214,6 +1241,9 @@ def order_by(self, *props): self.order_by_elements.append("?") else: for prop in props: + if isinstance(prop, RawCypher): + self.order_by_elements.append(prop) + continue prop = prop.strip() if prop.startswith("-"): prop = prop[1:] diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 09b3a3e8..7b11b928 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -26,6 +26,7 @@ Collect, Last, Optional, + RawCypher, RelationNameResolver, ) from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined @@ -84,6 +85,11 @@ class PersonX(AsyncStructuredNode): city = AsyncRelationshipTo(CityX, "LIVES_IN") +class SoftwareDependency(AsyncStructuredNode): + name = StringProperty(required=True) + version = StringProperty(required=True) + + @mark_async_test async def test_filter_exclude_via_labels(): await Coffee(name="Java", price=99).save() @@ -330,6 +336,29 @@ async def test_order_by(): assert ordered_n[2] == c3 +@mark_async_test +async def test_order_by_rawcypher(): + # Clean DB before we start anything... 
+ await adb.cypher_query("MATCH (n) DETACH DELETE n") + + d1 = await SoftwareDependency(name="Package1", version="1.0.0").save() + d2 = await SoftwareDependency(name="Package2", version="1.4.0").save() + d3 = await SoftwareDependency(name="Package3", version="2.5.5").save() + + assert ( + await SoftwareDependency.nodes.order_by( + RawCypher("toInteger(split($n.version, '.')[0]) DESC"), + ).all() + )[0] == d3 + + with raises( + ValueError, match=r"RawCypher: Do not include any action that has side effect" + ): + SoftwareDependency.nodes.order_by( + RawCypher("DETACH DELETE $n"), + ) + + @mark_async_test async def test_extra_filters(): for c in await Coffee.nodes: diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 942b8482..39270124 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -26,6 +26,7 @@ NodeSet, Optional, QueryBuilder, + RawCypher, RelationNameResolver, Traversal, ) @@ -82,6 +83,11 @@ class PersonX(StructuredNode): city = RelationshipTo(CityX, "LIVES_IN") +class SoftwareDependency(StructuredNode): + name = StringProperty(required=True) + version = StringProperty(required=True) + + @mark_sync_test def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() @@ -326,6 +332,29 @@ def test_order_by(): assert ordered_n[2] == c3 +@mark_sync_test +def test_order_by_rawcypher(): + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") + + d1 = SoftwareDependency(name="Package1", version="1.0.0").save() + d2 = SoftwareDependency(name="Package2", version="1.4.0").save() + d3 = SoftwareDependency(name="Package3", version="2.5.5").save() + + assert ( + SoftwareDependency.nodes.order_by( + RawCypher("toInteger(split($n.version, '.')[0]) DESC"), + ).all() + )[0] == d3 + + with raises( + ValueError, match=r"RawCypher: Do not include any action that has side effect" + ): + SoftwareDependency.nodes.order_by( + RawCypher("DETACH DELETE $n"), + ) + + @mark_sync_test def test_extra_filters(): for c in Coffee.nodes: From ab67d981a8dde9f72b5e5d0b6e1882b37bf264ea Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 14 Oct 2024 13:41:55 +0200 Subject: [PATCH 22/42] Removed python 3.7 from test matrix --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 032a5868..25b26cda 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12", "3.11", "3.10", "3.9", "3.8", "3.7"] + python-version: ["3.12", "3.11", "3.10", "3.9", "3.8"] neo4j-version: ["community", "enterprise", "5.5-enterprise", "4.4-enterprise", "4.4-community"] steps: From c16fb1a3f00aa1920bb34cec93af22150d62dbb6 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 14 Oct 2024 15:42:38 +0200 Subject: [PATCH 23/42] Possibility to filter and order by relation property. 
--- neomodel/async_/match.py | 115 ++++++++++++++++++++++----------- neomodel/sync_/match.py | 117 +++++++++++++++++++++++----------- test/async_/test_match_api.py | 57 ++++++++++++++++- test/sync_/test_match_api.py | 53 ++++++++++++++- 4 files changed, 262 insertions(+), 80 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index ecb5263f..80fb0ee9 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -201,7 +201,7 @@ def _rel_merge_helper( # add all regex operators OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) -path_split_regex = re.compile(r"__(?!_)") +path_split_regex = re.compile(r"__(?!_)|\|") def install_traversals(cls, node_set): @@ -271,29 +271,38 @@ def _deflate_value( def _initialize_filter_args_variables(cls, key: str): current_class = cls + current_rel_model = None leaf_prop = None operator = "=" + is_rel_property = "|" in key prop = key - return current_class, leaf_prop, operator, prop + return current_class, current_rel_model, leaf_prop, operator, is_rel_property, prop def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: ( current_class, + current_rel_model, leaf_prop, operator, + is_rel_property, prop, ) = _initialize_filter_args_variables(cls, key) for part in re.split(path_split_regex, key): defined_props = current_class.defined_properties(rels=True) + # update defined props dictionary with relationship properties if + # we are filtering by property + if is_rel_property and current_rel_model: + defined_props.update(current_rel_model.defined_properties(rels=True)) if part in defined_props: if isinstance( defined_props[part], relationship_manager.AsyncRelationshipDefinition ): defined_props[part].lookup_node_class() current_class = defined_props[part].definition["node_class"] + current_rel_model = defined_props[part].definition["model"] elif part in OPERATOR_TABLE: operator = OPERATOR_TABLE[part] prop, _ = prop.rsplit("__", 1) @@ -304,7 +313,10 @@ def _process_filter_key(cls, key: str) -> 
Tuple[Property, str, str]: ) leaf_prop = part - property_obj = getattr(current_class, leaf_prop) + if is_rel_property and current_rel_model: + property_obj = getattr(current_rel_model, leaf_prop) + else: + property_obj = getattr(current_class, leaf_prop) return property_obj, operator, prop @@ -467,7 +479,8 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: if isinstance(elm, RawCypher): order_by.append(elm.render({"n": ident})) continue - if "__" not in elm: + is_rel_property = "|" in elm + if "__" not in elm and not is_rel_property: prop = elm.split(" ")[0] if " " in elm else elm if prop not in source.source_class.defined_properties(rels=False): raise ValueError( @@ -477,8 +490,10 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: ) order_by.append(f"{ident}.{elm}") else: - path, prop = elm.rsplit("__", 1) - order_by_clause = self.lookup_query_variable(path) + path, prop = elm.rsplit("__" if not is_rel_property else "|", 1) + order_by_clause = self.lookup_query_variable( + path, return_relation=is_rel_property + ) order_by.append(f"{order_by_clause}.{prop}") self._ast.order_by = order_by @@ -521,6 +536,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: parts = re.split(path_split_regex, path) subgraph = self._ast.subgraph rel_iterator: str = "" + already_present = False + existing_rhs_name = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) if rel_iterator: @@ -529,19 +546,6 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # build source if "node_class" not in relationship.definition: relationship.lookup_node_class() - rhs_label = relationship.definition["node_class"].__label__ - rel_reference = f'{relationship.definition["node_class"]}_{part}' - self._node_counters[rel_reference] += 1 - if index + 1 == len(parts) and "alias" in relation: - # If an alias is defined, use it to store the last hop in the path - rhs_name = 
relation["alias"] - else: - rhs_name = ( - f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" - ) - rhs_ident = f"{rhs_name}:{rhs_label}" - if relation["include_in_return"]: - self._additional_return(rhs_name) if not stmt: lhs_label = source_class_iterator.__label__ lhs_name = lhs_label.lower() @@ -559,15 +563,37 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt + already_present = part in subgraph rel_ident = self.create_ident() - if part not in self._ast.subgraph: + rhs_label = relationship.definition["node_class"].__label__ + if relation.get("relation_filtering"): + rhs_name = rel_ident + else: + rel_reference = f'{relationship.definition["node_class"]}_{part}' + self._node_counters[rel_reference] += 1 + if index + 1 == len(parts) and "alias" in relation: + # If an alias is defined, use it to store the last hop in the path + rhs_name = relation["alias"] + else: + rhs_name = f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + rhs_ident = f"{rhs_name}:{rhs_label}" + if relation["include_in_return"] and not already_present: + self._additional_return(rhs_name) + + if not already_present: subgraph[part] = { "target": relationship.definition["node_class"], "children": {}, "variable_name": rhs_name, "rel_variable_name": rel_ident, } - if relation["include_in_return"]: + else: + existing_rhs_name = subgraph[part][ + "rel_variable_name" + if relation["relation_filtering"] + else "variable_name" + ] + if relation["include_in_return"] and not already_present: self._additional_return(rel_ident) stmt = _rel_helper( lhs=lhs_ident, @@ -579,11 +605,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: source_class_iterator = relationship.definition["node_class"] subgraph = subgraph[part]["children"] - if relation.get("optional"): - self._ast.optional_match.append(stmt) - else: - self._ast.match.append(stmt) - return rhs_name + if not already_present: + if 
relation.get("optional"): + self._ast.optional_match.append(stmt) + else: + self._ast.match.append(stmt) + return rhs_name + + return existing_rhs_name async def build_node(self, node): ident = node.__class__.__name__.lower() @@ -643,11 +672,21 @@ def _register_place_holder(self, key: str) -> str: return key + "_" + str(self._place_holder_registry[key]) def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: - path, prop = prop.rsplit("__", 1) - ident = self.build_traversal_from_path( - {"path": path, "include_in_return": True}, - source_class, - ) + is_rel_filter = "|" in prop + ident = self.lookup_query_variable(prop) + if is_rel_filter: + path, prop = prop.rsplit("|", 1) + else: + path, prop = prop.rsplit("__", 1) + if not ident: + ident = self.build_traversal_from_path( + { + "path": path, + "include_in_return": True, + "relation_filtering": is_rel_filter, + }, + source_class, + ) return ident, path, prop def _finalize_filter_statement( @@ -675,12 +714,14 @@ def _build_filter_statements( ) -> None: for prop, op_and_val in filters.items(): path = None - if "__" in prop: + is_rel_filter = "|" in prop + if "__" in prop or is_rel_filter: ident, path, prop = self._parse_path(source_class, prop) operator, val = op_and_val - prop = source_class.defined_properties(rels=False)[ - prop - ].get_db_property_name(prop) + if not is_rel_filter: + prop = source_class.defined_properties(rels=False)[ + prop + ].get_db_property_name(prop) statement = self._finalize_filter_statement(operator, ident, prop, val) target.append(statement) @@ -743,6 +784,7 @@ def lookup_query_variable( subgraph = self._ast.subgraph if not subgraph: return None + is_rel_property = "|" in path traversals = re.split(path_split_regex, path) if len(traversals) == 0: raise ValueError("Can only lookup traversal variables") @@ -757,7 +799,7 @@ def lookup_query_variable( elif part == last_property: # if last part of prop is the last traversal # we are safe to lookup the variable from the query - 
if return_relation: + if is_rel_property or return_relation: variable_to_return = f"{subgraph['rel_variable_name']}" else: variable_to_return = f"{subgraph['variable_name']}" @@ -901,7 +943,6 @@ async def _execute(self, lazy: bool = False, dict_output: bool = False): for item in self._ast.additional_return ] query = self.build_query() - print(query) results, prop_names = await adb.cypher_query( query, self._query_params, resolve_objects=True ) diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 64696959..cf344b5a 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -201,7 +201,7 @@ def _rel_merge_helper( # add all regex operators OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) -path_split_regex = re.compile(r"__(?!_)") +path_split_regex = re.compile(r"__(?!_)|\|") def install_traversals(cls, node_set): @@ -271,29 +271,38 @@ def _deflate_value( def _initialize_filter_args_variables(cls, key: str): current_class = cls + current_rel_model = None leaf_prop = None operator = "=" + is_rel_property = "|" in key prop = key - return current_class, leaf_prop, operator, prop + return current_class, current_rel_model, leaf_prop, operator, is_rel_property, prop def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: ( current_class, + current_rel_model, leaf_prop, operator, + is_rel_property, prop, ) = _initialize_filter_args_variables(cls, key) for part in re.split(path_split_regex, key): defined_props = current_class.defined_properties(rels=True) + # update defined props dictionary with relationship properties if + # we are filtering by property + if is_rel_property and current_rel_model: + defined_props.update(current_rel_model.defined_properties(rels=True)) if part in defined_props: if isinstance( defined_props[part], relationship_manager.RelationshipDefinition ): defined_props[part].lookup_node_class() current_class = defined_props[part].definition["node_class"] + current_rel_model = defined_props[part].definition["model"] elif 
part in OPERATOR_TABLE: operator = OPERATOR_TABLE[part] prop, _ = prop.rsplit("__", 1) @@ -304,7 +313,10 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: ) leaf_prop = part - property_obj = getattr(current_class, leaf_prop) + if is_rel_property and current_rel_model: + property_obj = getattr(current_rel_model, leaf_prop) + else: + property_obj = getattr(current_class, leaf_prop) return property_obj, operator, prop @@ -467,7 +479,8 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None: if isinstance(elm, RawCypher): order_by.append(elm.render({"n": ident})) continue - if "__" not in elm: + is_rel_property = "|" in elm + if "__" not in elm and not is_rel_property: prop = elm.split(" ")[0] if " " in elm else elm if prop not in source.source_class.defined_properties(rels=False): raise ValueError( @@ -477,8 +490,10 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None: ) order_by.append(f"{ident}.{elm}") else: - path, prop = elm.rsplit("__", 1) - order_by_clause = self.lookup_query_variable(path) + path, prop = elm.rsplit("__" if not is_rel_property else "|", 1) + order_by_clause = self.lookup_query_variable( + path, return_relation=is_rel_property + ) order_by.append(f"{order_by_clause}.{prop}") self._ast.order_by = order_by @@ -521,6 +536,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: parts = re.split(path_split_regex, path) subgraph = self._ast.subgraph rel_iterator: str = "" + already_present = False + existing_rhs_name = "" for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) if rel_iterator: @@ -529,19 +546,6 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # build source if "node_class" not in relationship.definition: relationship.lookup_node_class() - rhs_label = relationship.definition["node_class"].__label__ - rel_reference = f'{relationship.definition["node_class"]}_{part}' - self._node_counters[rel_reference] += 1 - 
if index + 1 == len(parts) and "alias" in relation: - # If an alias is defined, use it to store the last hop in the path - rhs_name = relation["alias"] - else: - rhs_name = ( - f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" - ) - rhs_ident = f"{rhs_name}:{rhs_label}" - if relation["include_in_return"]: - self._additional_return(rhs_name) if not stmt: lhs_label = source_class_iterator.__label__ lhs_name = lhs_label.lower() @@ -559,15 +563,39 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: lhs_ident = stmt + already_present = part in subgraph rel_ident = self.create_ident() - if part not in self._ast.subgraph: + rhs_label = relationship.definition["node_class"].__label__ + if relation.get("relation_filtering"): + rhs_name = rel_ident + else: + rel_reference = f'{relationship.definition["node_class"]}_{part}' + self._node_counters[rel_reference] += 1 + if index + 1 == len(parts) and "alias" in relation: + # If an alias is defined, use it to store the last hop in the path + rhs_name = relation["alias"] + else: + rhs_name = f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + rhs_ident = f"{rhs_name}:{rhs_label}" + if relation["include_in_return"] and not already_present: + self._additional_return(rhs_name) + + if not already_present: subgraph[part] = { "target": relationship.definition["node_class"], "children": {}, "variable_name": rhs_name, "rel_variable_name": rel_ident, } - if relation["include_in_return"]: + else: + existing_rhs_name = subgraph[part][ + ( + "rel_variable_name" + if relation["relation_filtering"] + else "variable_name" + ) + ] + if relation["include_in_return"] and not already_present: self._additional_return(rel_ident) stmt = _rel_helper( lhs=lhs_ident, @@ -579,11 +607,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: source_class_iterator = relationship.definition["node_class"] subgraph = subgraph[part]["children"] - if relation.get("optional"): - 
self._ast.optional_match.append(stmt) - else: - self._ast.match.append(stmt) - return rhs_name + if not already_present: + if relation.get("optional"): + self._ast.optional_match.append(stmt) + else: + self._ast.match.append(stmt) + return rhs_name + + return existing_rhs_name def build_node(self, node): ident = node.__class__.__name__.lower() @@ -643,11 +674,21 @@ def _register_place_holder(self, key: str) -> str: return key + "_" + str(self._place_holder_registry[key]) def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: - path, prop = prop.rsplit("__", 1) - ident = self.build_traversal_from_path( - {"path": path, "include_in_return": True}, - source_class, - ) + is_rel_filter = "|" in prop + ident = self.lookup_query_variable(prop) + if is_rel_filter: + path, prop = prop.rsplit("|", 1) + else: + path, prop = prop.rsplit("__", 1) + if not ident: + ident = self.build_traversal_from_path( + { + "path": path, + "include_in_return": True, + "relation_filtering": is_rel_filter, + }, + source_class, + ) return ident, path, prop def _finalize_filter_statement( @@ -675,12 +716,14 @@ def _build_filter_statements( ) -> None: for prop, op_and_val in filters.items(): path = None - if "__" in prop: + is_rel_filter = "|" in prop + if "__" in prop or is_rel_filter: ident, path, prop = self._parse_path(source_class, prop) operator, val = op_and_val - prop = source_class.defined_properties(rels=False)[ - prop - ].get_db_property_name(prop) + if not is_rel_filter: + prop = source_class.defined_properties(rels=False)[ + prop + ].get_db_property_name(prop) statement = self._finalize_filter_statement(operator, ident, prop, val) target.append(statement) @@ -743,6 +786,7 @@ def lookup_query_variable( subgraph = self._ast.subgraph if not subgraph: return None + is_rel_property = "|" in path traversals = re.split(path_split_regex, path) if len(traversals) == 0: raise ValueError("Can only lookup traversal variables") @@ -757,7 +801,7 @@ def lookup_query_variable( elif 
part == last_property: # if last part of prop is the last traversal # we are safe to lookup the variable from the query - if return_relation: + if is_rel_property or return_relation: variable_to_return = f"{subgraph['rel_variable_name']}" else: variable_to_return = f"{subgraph['variable_name']}" @@ -899,7 +943,6 @@ def _execute(self, lazy: bool = False, dict_output: bool = False): for item in self._ast.additional_return ] query = self.build_query() - print(query) results, prop_names = db.cypher_query( query, self._query_params, resolve_objects=True ) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 7b11b928..47d702cf 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -40,7 +40,7 @@ class SupplierRel(AsyncStructuredRel): class Supplier(AsyncStructuredNode): name = StringProperty() delivery_cost = IntegerProperty() - coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS") + coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS", model=SupplierRel) class Species(AsyncStructuredNode): @@ -594,6 +594,57 @@ async def test_filter_with_traversal(): assert results[0][0] == nescafe +@mark_async_test +async def test_relation_prop_filtering(): + # Clean DB before we start anything... 
+ await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) + await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) + await nescafe.species.connect(arabica) + + results = await Supplier.nodes.filter( + **{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)} + ).all() + + assert len(results) == 1 + assert results[0][0] == supplier1 + + +@mark_async_test +async def test_relation_prop_ordering(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) + await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) + await nescafe.species.connect(arabica) + + results = ( + await Supplier.nodes.fetch_relations("coffees").order_by("-coffees|since").all() + ) + assert len(results) == 2 + assert results[0][0] == supplier1 + assert results[1][0] == supplier2 + + results = ( + await Supplier.nodes.fetch_relations("coffees").order_by("coffees|since").all() + ) + assert len(results) == 2 + assert results[0][0] == supplier2 + assert results[1][0] == supplier1 + + @mark_async_test async def test_fetch_relations(): # Clean DB before we start anything... 
@@ -601,8 +652,8 @@ async def test_fetch_relations(): arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() - nescafe = await Coffee(name="Nescafe 1000", price=99).save() - nescafe_gold = await Coffee(name="Nescafe 1001", price=11).save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 39270124..832a5367 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -40,7 +40,7 @@ class SupplierRel(StructuredRel): class Supplier(StructuredNode): name = StringProperty() delivery_cost = IntegerProperty() - coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS") + coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS", model=SupplierRel) class Species(StructuredNode): @@ -588,6 +588,53 @@ def test_filter_with_traversal(): assert results[0][0] == nescafe +@mark_sync_test +def test_relation_prop_filtering(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) + nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) + nescafe.species.connect(arabica) + + results = Supplier.nodes.filter( + **{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)} + ).all() + + assert len(results) == 1 + assert results[0][0] == supplier1 + + +@mark_sync_test +def test_relation_prop_ordering(): + # Clean DB before we start anything... 
+ db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)}) + nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)}) + nescafe.species.connect(arabica) + + results = Supplier.nodes.fetch_relations("coffees").order_by("-coffees|since").all() + assert len(results) == 2 + assert results[0][0] == supplier1 + assert results[1][0] == supplier2 + + results = Supplier.nodes.fetch_relations("coffees").order_by("coffees|since").all() + assert len(results) == 2 + assert results[0][0] == supplier2 + assert results[1][0] == supplier1 + + @mark_sync_test def test_fetch_relations(): # Clean DB before we start anything... @@ -595,8 +642,8 @@ def test_fetch_relations(): arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() - nescafe = Coffee(name="Nescafe 1000", price=99).save() - nescafe_gold = Coffee(name="Nescafe 1001", price=11).save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() tesco = Supplier(name="Sainsburys", delivery_cost=3).save() nescafe.suppliers.connect(tesco) From 94afca2b1fed929003ddb7a9757aab2aaefe3178 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Fri, 18 Oct 2024 15:31:01 +0200 Subject: [PATCH 24/42] Various fixes --- neomodel/async_/match.py | 59 +++++++++++++++++++++++----------------- neomodel/sync_/match.py | 59 +++++++++++++++++++++++----------------- 2 files changed, 68 insertions(+), 50 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 80fb0ee9..d026f597 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -419,7 +419,6 @@ def __init__(self, node_set, subquery_context: bool = False) -> None: 
self._query_params: Dict = {} self._place_holder_registry: Dict = {} self._ident_count: int = 0 - self._node_counters = defaultdict(int) self._subquery_context: bool = subquery_context async def build_ast(self) -> "AsyncQueryBuilder": @@ -491,10 +490,11 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: order_by.append(f"{ident}.{elm}") else: path, prop = elm.rsplit("__" if not is_rel_property else "|", 1) - order_by_clause = self.lookup_query_variable( + result = self.lookup_query_variable( path, return_relation=is_rel_property ) - order_by.append(f"{order_by_clause}.{prop}") + if result: + order_by.append(f"{result[0]}.{prop}") self._ast.order_by = order_by async def build_traversal(self, traversal) -> str: @@ -529,7 +529,9 @@ def _additional_return(self, name: str): if name not in self._ast.additional_return and name != self._ast.return_clause: self._ast.additional_return.append(name) - def build_traversal_from_path(self, relation: dict, source_class) -> str: + def build_traversal_from_path( + self, relation: dict, source_class + ) -> Tuple[str, Any]: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class @@ -570,12 +572,11 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: rhs_name = rel_ident else: rel_reference = f'{relationship.definition["node_class"]}_{part}' - self._node_counters[rel_reference] += 1 if index + 1 == len(parts) and "alias" in relation: # If an alias is defined, use it to store the last hop in the path rhs_name = relation["alias"] else: - rhs_name = f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + rhs_name = f"{rhs_label.lower()}_{rel_iterator}" rhs_ident = f"{rhs_name}:{rhs_label}" if relation["include_in_return"] and not already_present: self._additional_return(rhs_name) @@ -590,7 +591,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: else: existing_rhs_name = subgraph[part][ "rel_variable_name" - if 
relation["relation_filtering"] + if relation.get("relation_filtering") else "variable_name" ] if relation["include_in_return"] and not already_present: @@ -610,9 +611,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: self._ast.optional_match.append(stmt) else: self._ast.match.append(stmt) - return rhs_name + return rhs_name, relationship.definition["node_class"] - return existing_rhs_name + return existing_rhs_name, relationship.definition["node_class"] async def build_node(self, node): ident = node.__class__.__name__.lower() @@ -671,15 +672,15 @@ def _register_place_holder(self, key: str) -> str: self._place_holder_registry[key] = 1 return key + "_" + str(self._place_holder_registry[key]) - def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: + def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop - ident = self.lookup_query_variable(prop) if is_rel_filter: path, prop = prop.rsplit("|", 1) else: path, prop = prop.rsplit("__", 1) - if not ident: - ident = self.build_traversal_from_path( + result = self.lookup_query_variable(path, return_relation=is_rel_filter) + if not result: + ident, target_class = self.build_traversal_from_path( { "path": path, "include_in_return": True, @@ -687,7 +688,9 @@ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: }, source_class, ) - return ident, path, prop + else: + ident, target_class = result + return ident, path, prop, target_class def _finalize_filter_statement( self, operator: str, ident: str, prop: str, val: Any @@ -715,11 +718,12 @@ def _build_filter_statements( for prop, op_and_val in filters.items(): path = None is_rel_filter = "|" in prop + target_class = source_class if "__" in prop or is_rel_filter: - ident, path, prop = self._parse_path(source_class, prop) + ident, path, prop, target_class = self._parse_path(source_class, prop) operator, val = op_and_val if not is_rel_filter: - prop = 
source_class.defined_properties(rels=False)[ + prop = target_class.defined_properties(rels=False)[ prop ].get_db_property_name(prop) statement = self._finalize_filter_statement(operator, ident, prop, val) @@ -779,19 +783,18 @@ def build_where_stmt( def lookup_query_variable( self, path: str, return_relation: bool = False - ) -> TOptional[str]: + ) -> TOptional[Tuple[str, Any]]: """Retrieve the variable name generated internally for the given traversal path.""" subgraph = self._ast.subgraph if not subgraph: return None - is_rel_property = "|" in path traversals = re.split(path_split_regex, path) if len(traversals) == 0: raise ValueError("Can only lookup traversal variables") if traversals[0] not in subgraph: return None subgraph = subgraph[traversals[0]] - variable_to_return = None + variable_to_return = "" last_property = traversals[-1] for part in traversals: if part in subgraph["children"]: @@ -799,13 +802,13 @@ def lookup_query_variable( elif part == last_property: # if last part of prop is the last traversal # we are safe to lookup the variable from the query - if is_rel_property or return_relation: + if return_relation: variable_to_return = f"{subgraph['rel_variable_name']}" else: variable_to_return = f"{subgraph['variable_name']}" else: - break - return variable_to_return + return None + return variable_to_return, subgraph["target"] def build_query(self) -> str: query: str = "" @@ -827,6 +830,9 @@ def build_query(self) -> str: query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match) if self._ast.where: + if self._ast.optional_match: + # Make sure filtering works as expected with optional match, even if it's not performant... 
+ query += " WITH *" query += " WHERE " query += " AND ".join(self._ast.where) @@ -842,20 +848,23 @@ def build_query(self) -> str: if type(source) is str: injected_vars.append(f"{source} AS {name}") elif isinstance(source, RelationNameResolver): - internal_name = self.lookup_query_variable( + result = self.lookup_query_variable( source.relation, return_relation=True ) - if not internal_name: + if not result: raise ValueError( f"Unable to resolve variable name for relation {source.relation}." ) - injected_vars.append(f"{internal_name} AS {name}") + injected_vars.append(f"{result[0]} AS {name}") query += ",".join(injected_vars) if not transform["ordering"]: continue query += " ORDER BY " ordering: list = [] for item in transform["ordering"]: + if isinstance(item, RawCypher): + ordering.append(item.render({})) + continue if item.startswith("-"): ordering.append(f"{item[1:]} DESC") else: diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index cf344b5a..b41e0877 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -419,7 +419,6 @@ def __init__(self, node_set, subquery_context: bool = False) -> None: self._query_params: Dict = {} self._place_holder_registry: Dict = {} self._ident_count: int = 0 - self._node_counters = defaultdict(int) self._subquery_context: bool = subquery_context def build_ast(self) -> "QueryBuilder": @@ -491,10 +490,11 @@ def build_order_by(self, ident: str, source: "NodeSet") -> None: order_by.append(f"{ident}.{elm}") else: path, prop = elm.rsplit("__" if not is_rel_property else "|", 1) - order_by_clause = self.lookup_query_variable( + result = self.lookup_query_variable( path, return_relation=is_rel_property ) - order_by.append(f"{order_by_clause}.{prop}") + if result: + order_by.append(f"{result[0]}.{prop}") self._ast.order_by = order_by def build_traversal(self, traversal) -> str: @@ -529,7 +529,9 @@ def _additional_return(self, name: str): if name not in self._ast.additional_return and name != 
self._ast.return_clause: self._ast.additional_return.append(name) - def build_traversal_from_path(self, relation: dict, source_class) -> str: + def build_traversal_from_path( + self, relation: dict, source_class + ) -> Tuple[str, Any]: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class @@ -570,12 +572,11 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: rhs_name = rel_ident else: rel_reference = f'{relationship.definition["node_class"]}_{part}' - self._node_counters[rel_reference] += 1 if index + 1 == len(parts) and "alias" in relation: # If an alias is defined, use it to store the last hop in the path rhs_name = relation["alias"] else: - rhs_name = f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + rhs_name = f"{rhs_label.lower()}_{rel_iterator}" rhs_ident = f"{rhs_name}:{rhs_label}" if relation["include_in_return"] and not already_present: self._additional_return(rhs_name) @@ -591,7 +592,7 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: existing_rhs_name = subgraph[part][ ( "rel_variable_name" - if relation["relation_filtering"] + if relation.get("relation_filtering") else "variable_name" ) ] @@ -612,9 +613,9 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: self._ast.optional_match.append(stmt) else: self._ast.match.append(stmt) - return rhs_name + return rhs_name, relationship.definition["node_class"] - return existing_rhs_name + return existing_rhs_name, relationship.definition["node_class"] def build_node(self, node): ident = node.__class__.__name__.lower() @@ -673,15 +674,15 @@ def _register_place_holder(self, key: str) -> str: self._place_holder_registry[key] = 1 return key + "_" + str(self._place_holder_registry[key]) - def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: + def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop - ident = 
self.lookup_query_variable(prop) if is_rel_filter: path, prop = prop.rsplit("|", 1) else: path, prop = prop.rsplit("__", 1) - if not ident: - ident = self.build_traversal_from_path( + result = self.lookup_query_variable(path, return_relation=is_rel_filter) + if not result: + ident, target_class = self.build_traversal_from_path( { "path": path, "include_in_return": True, @@ -689,7 +690,9 @@ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str]: }, source_class, ) - return ident, path, prop + else: + ident, target_class = result + return ident, path, prop, target_class def _finalize_filter_statement( self, operator: str, ident: str, prop: str, val: Any @@ -717,11 +720,12 @@ def _build_filter_statements( for prop, op_and_val in filters.items(): path = None is_rel_filter = "|" in prop + target_class = source_class if "__" in prop or is_rel_filter: - ident, path, prop = self._parse_path(source_class, prop) + ident, path, prop, target_class = self._parse_path(source_class, prop) operator, val = op_and_val if not is_rel_filter: - prop = source_class.defined_properties(rels=False)[ + prop = target_class.defined_properties(rels=False)[ prop ].get_db_property_name(prop) statement = self._finalize_filter_statement(operator, ident, prop, val) @@ -781,19 +785,18 @@ def build_where_stmt( def lookup_query_variable( self, path: str, return_relation: bool = False - ) -> TOptional[str]: + ) -> TOptional[Tuple[str, Any]]: """Retrieve the variable name generated internally for the given traversal path.""" subgraph = self._ast.subgraph if not subgraph: return None - is_rel_property = "|" in path traversals = re.split(path_split_regex, path) if len(traversals) == 0: raise ValueError("Can only lookup traversal variables") if traversals[0] not in subgraph: return None subgraph = subgraph[traversals[0]] - variable_to_return = None + variable_to_return = "" last_property = traversals[-1] for part in traversals: if part in subgraph["children"]: @@ -801,13 +804,13 @@ def 
lookup_query_variable( elif part == last_property: # if last part of prop is the last traversal # we are safe to lookup the variable from the query - if is_rel_property or return_relation: + if return_relation: variable_to_return = f"{subgraph['rel_variable_name']}" else: variable_to_return = f"{subgraph['variable_name']}" else: - break - return variable_to_return + return None + return variable_to_return, subgraph["target"] def build_query(self) -> str: query: str = "" @@ -829,6 +832,9 @@ def build_query(self) -> str: query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match) if self._ast.where: + if self._ast.optional_match: + # Make sure filtering works as expected with optional match, even if it's not performant... + query += " WITH *" query += " WHERE " query += " AND ".join(self._ast.where) @@ -844,20 +850,23 @@ def build_query(self) -> str: if type(source) is str: injected_vars.append(f"{source} AS {name}") elif isinstance(source, RelationNameResolver): - internal_name = self.lookup_query_variable( + result = self.lookup_query_variable( source.relation, return_relation=True ) - if not internal_name: + if not result: raise ValueError( f"Unable to resolve variable name for relation {source.relation}." 
) - injected_vars.append(f"{internal_name} AS {name}") + injected_vars.append(f"{result[0]} AS {name}") query += ",".join(injected_vars) if not transform["ordering"]: continue query += " ORDER BY " ordering: list = [] for item in transform["ordering"]: + if isinstance(item, RawCypher): + ordering.append(item.render({})) + continue if item.startswith("-"): ordering.append(f"{item[1:]} DESC") else: From 29f8bfa58e931a7253647ba8d0e3aa78079a5bbd Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Tue, 22 Oct 2024 14:34:00 +0200 Subject: [PATCH 25/42] Improvements and fixes --- neomodel/async_/match.py | 176 +++++++++++++++++++++++++--------- neomodel/sync_/match.py | 176 +++++++++++++++++++++++++--------- test/async_/test_match_api.py | 7 +- test/sync_/test_issue283.py | 1 + test/sync_/test_match_api.py | 7 +- 5 files changed, 267 insertions(+), 100 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index d026f597..b34938c6 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1,7 +1,6 @@ import inspect import re import string -from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -571,7 +570,6 @@ def build_traversal_from_path( if relation.get("relation_filtering"): rhs_name = rel_ident else: - rel_reference = f'{relationship.definition["node_class"]}_{part}' if index + 1 == len(parts) and "alias" in relation: # If an alias is defined, use it to store the last hop in the path rhs_name = relation["alias"] @@ -794,20 +792,20 @@ def lookup_query_variable( if traversals[0] not in subgraph: return None subgraph = subgraph[traversals[0]] + if len(traversals) == 1: + variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}" + return variable_to_return, subgraph["target"] variable_to_return = "" last_property = traversals[-1] - for part in traversals: - if part in subgraph["children"]: - subgraph = 
subgraph["children"][part] - elif part == last_property: + for part in traversals[1:]: + child = subgraph["children"].get(part) + if not child: + return None + subgraph = child + if part == last_property: # if last part of prop is the last traversal # we are safe to lookup the variable from the query - if return_relation: - variable_to_return = f"{subgraph['rel_variable_name']}" - else: - variable_to_return = f"{subgraph['variable_name']}" - else: - return None + variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}" return variable_to_return, subgraph["target"] def build_query(self) -> str: @@ -844,6 +842,9 @@ def build_query(self) -> str: for transform in self.node_set._intermediate_transforms: query += " WITH " injected_vars: list = [] + # Reset return list since we'll probably invalidate most variables + self._ast.return_clause = "" + self._ast.additional_return = [] for name, source in transform["vars"].items(): if type(source) is str: injected_vars.append(f"{source} AS {name}") @@ -856,6 +857,13 @@ def build_query(self) -> str: f"Unable to resolve variable name for relation {source.relation}." ) injected_vars.append(f"{result[0]} AS {name}") + elif isinstance(source, NodeNameResolver): + result = self.lookup_query_variable(source.node) + if not result: + raise ValueError( + f"Unable to resolve variable name for node {source.node}." + ) + injected_vars.append(f"{result[0]} AS {name}") query += ",".join(injected_vars) if not transform["ordering"]: continue @@ -876,6 +884,17 @@ def build_query(self) -> str: for subquery, return_set in self.node_set._subqueries: outer_primary_var = self._ast.return_clause query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " + for varname in return_set: + # We declare the returned variables as "virtual" relations of the + # root node class to make sure they will be translated by a call to + # resolve_subgraph() (otherwise, they will be lost). 
+ # This is probably a temporary solution until we find something better... + self._ast.subgraph[varname] = { + "target": None, # We don't need target class in this use case + "children": {}, + "variable_name": varname, + "rel_variable_name": varname, + } returned_items += return_set query += " RETURN " @@ -884,12 +903,18 @@ def build_query(self) -> str: if self._ast.additional_return: returned_items += self._ast.additional_return if hasattr(self.node_set, "_extra_results"): - for varname, vardef in self.node_set._extra_results.items(): + for props in self.node_set._extra_results: + leftpart = props["vardef"].render(self) + varname = ( + props["alias"] + if props.get("alias") + else props["vardef"].get_internal_name() + ) if varname in returned_items: # We're about to override an existing variable, delete it first to # avoid duplicate error returned_items.remove(varname) - returned_items.append(f"{str(vardef)} AS {varname}") + returned_items.append(f"{leftpart} AS {varname}") query += ", ".join(returned_items) @@ -1062,10 +1087,62 @@ class Optional: @dataclass -class AggregatingFunction: +class RelationNameResolver: + """Helper to refer to a relation variable name. + + Since variable names are generated automatically within MATCH statements (for + anything injected using fetch_relations or traverse_relations), we need a way to + retrieve them. + + """ + + relation: str + + +@dataclass +class NodeNameResolver: + """Helper to refer to a node variable name. + + Since variable names are generated automatically within MATCH statements (for + anything injected using fetch_relations or traverse_relations), we need a way to + retrieve them. 
+ + """ + + node: str + + +@dataclass +class BaseFunction: + input_name: Union[str, "BaseFunction", NodeNameResolver, RelationNameResolver] + + def __post_init__(self) -> None: + self._internal_name: str = "" + + def get_internal_name(self) -> str: + return self._internal_name + + def resolve_internal_name(self, qbuilder: AsyncQueryBuilder) -> str: + if isinstance(self.input_name, NodeNameResolver): + result = qbuilder.lookup_query_variable(self.input_name.node) + elif isinstance(self.input_name, RelationNameResolver): + result = qbuilder.lookup_query_variable(self.input_name.relation, True) + else: + result = (str(self.input_name), None) + if result is None: + raise ValueError(f"Unknown variable {self.input_name} used in Collect()") + self._internal_name = result[0] + return self._internal_name + + def render(self, qbuilder: AsyncQueryBuilder) -> str: + raise NotImplementedError + + +@dataclass +class AggregatingFunction(BaseFunction): """Base aggregating function class.""" - input_name: str + pass @dataclass @@ -1074,38 +1151,33 @@ class Collect(AggregatingFunction): distinct: bool = False - def __str__(self): + def render(self, qbuilder: AsyncQueryBuilder) -> str: + varname = self.resolve_internal_name(qbuilder) if self.distinct: - return f"collect(DISTINCT {self.input_name})" - return f"collect({self.input_name})" + return f"collect(DISTINCT {varname})" + return f"collect({varname})" @dataclass -class ScalarFunction: +class ScalarFunction(BaseFunction): """Base scalar function class.""" - input_name: Union[str, AggregatingFunction] + pass @dataclass class Last(ScalarFunction): """last() function.""" - def __str__(self) -> str: - return f"last({str(self.input_name)})" - - -@dataclass -class RelationNameResolver: - """Helper to refer to a relation variable name. - - Since variable names are generated automatically within MATCH statements (for - anything injected using fetch_relations or traverse_relations), we need a way to - retrieve them. 
- - """ - - relation: str + def render(self, qbuilder: AsyncQueryBuilder) -> str: + if isinstance(self.input_name, str): + content = str(self.input_name) + elif isinstance(self.input_name, BaseFunction): + content = self.input_name.render(qbuilder) + self._internal_name = self.input_name.get_internal_name() + else: + content = self.resolve_internal_name(qbuilder) + return f"last({content})" @dataclass @@ -1156,7 +1228,7 @@ def __init__(self, source) -> None: self.dont_match: Dict = {} self.relations_to_fetch: List = [] - self._extra_results: dict = {} + self._extra_results: List = [] self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] @@ -1357,7 +1429,9 @@ def annotate(self, *vars, **aliased_vars): def register_extra_var(vardef, varname: Union[str, None] = None): if isinstance(vardef, (AggregatingFunction, ScalarFunction)): - self._extra_results[varname if varname else vardef.input_name] = vardef + self._extra_results.append( + {"vardef": vardef, "alias": varname if varname else ""} + ) else: raise NotImplementedError @@ -1411,17 +1485,20 @@ async def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ - if not self.relations_to_fetch: - raise RuntimeError( - "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." - ) - if not self.relations_to_fetch[0]["include_in_return"]: + if ( + self.relations_to_fetch + and not self.relations_to_fetch[0]["include_in_return"] + ): raise NotImplementedError( "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." ) results: list = [] qbuilder = self.query_cls(self) await qbuilder.build_ast() + if not qbuilder._ast.subgraph: + raise RuntimeError( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." 
+ ) all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} root_node = None @@ -1454,7 +1531,8 @@ async def subquery( if ( var != qbuilder._ast.return_clause and var not in qbuilder._ast.additional_return - and var not in nodeset._extra_results + and var + not in [res["alias"] for res in nodeset._extra_results if res["alias"]] ): raise RuntimeError(f"Variable '{var}' is not returned by subquery.") self._subqueries.append((qbuilder.build_query(), return_set)) @@ -1463,10 +1541,16 @@ async def subquery( def intermediate_transform( self, vars: Dict[str, Any], ordering: TOptional[list] = None ) -> "AsyncNodeSet": + if not vars: + raise ValueError( + "You must provide one variable at least when calling intermediate_transform()" + ) for name, source in vars.items(): - if type(source) is not str and not isinstance(source, RelationNameResolver): + if type(source) is not str and not isinstance( + source, (NodeNameResolver, RelationNameResolver) + ): raise ValueError( - f"Wrong source type specified for variable '{name}', should be a string or an instance of RelationNameResolver" + f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver" ) self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) return self diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index b41e0877..73715cc8 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,7 +1,6 @@ import inspect import re import string -from collections import defaultdict from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -571,7 +570,6 @@ def build_traversal_from_path( if relation.get("relation_filtering"): rhs_name = rel_ident else: - rel_reference = f'{relationship.definition["node_class"]}_{part}' if index + 1 == len(parts) and "alias" in relation: # If an alias is defined, use it to store the last hop in the path rhs_name = 
relation["alias"] @@ -796,20 +794,20 @@ def lookup_query_variable( if traversals[0] not in subgraph: return None subgraph = subgraph[traversals[0]] + if len(traversals) == 1: + variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}" + return variable_to_return, subgraph["target"] variable_to_return = "" last_property = traversals[-1] - for part in traversals: - if part in subgraph["children"]: - subgraph = subgraph["children"][part] - elif part == last_property: + for part in traversals[1:]: + child = subgraph["children"].get(part) + if not child: + return None + subgraph = child + if part == last_property: # if last part of prop is the last traversal # we are safe to lookup the variable from the query - if return_relation: - variable_to_return = f"{subgraph['rel_variable_name']}" - else: - variable_to_return = f"{subgraph['variable_name']}" - else: - return None + variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}" return variable_to_return, subgraph["target"] def build_query(self) -> str: @@ -846,6 +844,9 @@ def build_query(self) -> str: for transform in self.node_set._intermediate_transforms: query += " WITH " injected_vars: list = [] + # Reset return list since we'll probably invalidate most variables + self._ast.return_clause = "" + self._ast.additional_return = [] for name, source in transform["vars"].items(): if type(source) is str: injected_vars.append(f"{source} AS {name}") @@ -858,6 +859,13 @@ def build_query(self) -> str: f"Unable to resolve variable name for relation {source.relation}." ) injected_vars.append(f"{result[0]} AS {name}") + elif isinstance(source, NodeNameResolver): + result = self.lookup_query_variable(source.node) + if not result: + raise ValueError( + f"Unable to resolve variable name for node {source.node}." 
+ ) + injected_vars.append(f"{result[0]} AS {name}") query += ",".join(injected_vars) if not transform["ordering"]: continue @@ -878,6 +886,17 @@ def build_query(self) -> str: for subquery, return_set in self.node_set._subqueries: outer_primary_var = self._ast.return_clause query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " + for varname in return_set: + # We declare the returned variables as "virtual" relations of the + # root node class to make sure they will be translated by a call to + # resolve_subgraph() (otherwise, they will be lost). + # This is probably a temporary solution until we find something better... + self._ast.subgraph[varname] = { + "target": None, # We don't need target class in this use case + "children": {}, + "variable_name": varname, + "rel_variable_name": varname, + } returned_items += return_set query += " RETURN " @@ -886,12 +905,18 @@ def build_query(self) -> str: if self._ast.additional_return: returned_items += self._ast.additional_return if hasattr(self.node_set, "_extra_results"): - for varname, vardef in self.node_set._extra_results.items(): + for props in self.node_set._extra_results: + leftpart = props["vardef"].render(self) + varname = ( + props["alias"] + if props.get("alias") + else props["vardef"].get_internal_name() + ) if varname in returned_items: # We're about to override an existing variable, delete it first to # avoid duplicate error returned_items.remove(varname) - returned_items.append(f"{str(vardef)} AS {varname}") + returned_items.append(f"{leftpart} AS {varname}") query += ", ".join(returned_items) @@ -1062,10 +1087,62 @@ class Optional: @dataclass -class AggregatingFunction: +class RelationNameResolver: + """Helper to refer to a relation variable name. + + Since variable names are generated automatically within MATCH statements (for + anything injected using fetch_relations or traverse_relations), we need a way to + retrieve them. 
+ + """ + + relation: str + + +@dataclass +class NodeNameResolver: + """Helper to refer to a node variable name. + + Since variable names are generated automatically within MATCH statements (for + anything injected using fetch_relations or traverse_relations), we need a way to + retrieve them. + + """ + + node: str + + +@dataclass +class BaseFunction: + input_name: Union[str, "BaseFunction", NodeNameResolver, RelationNameResolver] + + def __post_init__(self) -> None: + self._internal_name: str = "" + + def get_internal_name(self) -> str: + return self._internal_name + + def resolve_internal_name(self, qbuilder: QueryBuilder) -> str: + if isinstance(self.input_name, NodeNameResolver): + result = qbuilder.lookup_query_variable(self.input_name.node) + elif isinstance(self.input_name, RelationNameResolver): + result = qbuilder.lookup_query_variable(self.input_name.relation, True) + else: + result = (str(self.input_name), None) + if result is None: + raise ValueError(f"Unknown variable {self.input_name} used in Collect()") + self._internal_name = result[0] + return self._internal_name + + def render(self, qbuilder: QueryBuilder) -> str: + raise NotImplementedError + + +@dataclass +class AggregatingFunction(BaseFunction): """Base aggregating function class.""" - input_name: str + pass @dataclass @@ -1074,38 +1151,33 @@ class Collect(AggregatingFunction): distinct: bool = False - def __str__(self): + def render(self, qbuilder: QueryBuilder) -> str: + varname = self.resolve_internal_name(qbuilder) if self.distinct: - return f"collect(DISTINCT {self.input_name})" - return f"collect({self.input_name})" + return f"collect(DISTINCT {varname})" + return f"collect({varname})" @dataclass -class ScalarFunction: +class ScalarFunction(BaseFunction): """Base scalar function class.""" - input_name: Union[str, AggregatingFunction] + pass @dataclass class Last(ScalarFunction): """last() function.""" - def __str__(self) -> str: - return f"last({str(self.input_name)})" - - -@dataclass 
-class RelationNameResolver: - """Helper to refer to a relation variable name. - - Since variable names are generated automatically within MATCH statements (for - anything injected using fetch_relations or traverse_relations), we need a way to - retrieve them. - - """ - - relation: str + def render(self, qbuilder: QueryBuilder) -> str: + if isinstance(self.input_name, str): + content = str(self.input_name) + elif isinstance(self.input_name, BaseFunction): + content = self.input_name.render(qbuilder) + self._internal_name = self.input_name.get_internal_name() + else: + content = self.resolve_internal_name(qbuilder) + return f"last({content})" @dataclass @@ -1156,7 +1228,7 @@ def __init__(self, source) -> None: self.dont_match: Dict = {} self.relations_to_fetch: List = [] - self._extra_results: dict = {} + self._extra_results: List = [] self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] @@ -1357,7 +1429,9 @@ def annotate(self, *vars, **aliased_vars): def register_extra_var(vardef, varname: Union[str, None] = None): if isinstance(vardef, (AggregatingFunction, ScalarFunction)): - self._extra_results[varname if varname else vardef.input_name] = vardef + self._extra_results.append( + {"vardef": vardef, "alias": varname if varname else ""} + ) else: raise NotImplementedError @@ -1411,17 +1485,20 @@ def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ - if not self.relations_to_fetch: - raise RuntimeError( - "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." - ) - if not self.relations_to_fetch[0]["include_in_return"]: + if ( + self.relations_to_fetch + and not self.relations_to_fetch[0]["include_in_return"] + ): raise NotImplementedError( "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." 
) results: list = [] qbuilder = self.query_cls(self) qbuilder.build_ast() + if not qbuilder._ast.subgraph: + raise RuntimeError( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ) all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} root_node = None @@ -1452,7 +1529,8 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": if ( var != qbuilder._ast.return_clause and var not in qbuilder._ast.additional_return - and var not in nodeset._extra_results + and var + not in [res["alias"] for res in nodeset._extra_results if res["alias"]] ): raise RuntimeError(f"Variable '{var}' is not returned by subquery.") self._subqueries.append((qbuilder.build_query(), return_set)) @@ -1461,10 +1539,16 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": def intermediate_transform( self, vars: Dict[str, Any], ordering: TOptional[list] = None ) -> "NodeSet": + if not vars: + raise ValueError( + "You must provide one variable at least when calling intermediate_transform()" + ) for name, source in vars.items(): - if type(source) is not str and not isinstance(source, RelationNameResolver): + if type(source) is not str and not isinstance( + source, (NodeNameResolver, RelationNameResolver) + ): raise ValueError( - f"Wrong source type specified for variable '{name}', should be a string or an instance of RelationNameResolver" + f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver" ) self._intermediate_transforms.append({"vars": vars, "ordering": ordering}) return self diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 47d702cf..a69b9834 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -677,7 +677,7 @@ async def test_fetch_relations(): .fetch_relations(Optional("coffees__suppliers")) .all() ) - assert result[0][0] is None + assert 
len(result) == 0 if AsyncUtil.is_async_code: count = ( @@ -914,13 +914,12 @@ async def test_intermediate_transform(): ) assert len(result) == 1 - assert len(result[0]) == 2 - assert result[0][1] == supplier2 + assert result[0] == supplier2 with raises( ValueError, match=re.escape( - r"Wrong source type specified for variable 'test', should be a string or an instance of RelationNameResolver" + r"Wrong source type specified for variable 'test', should be a string or an instance of NodeNameResolver or RelationNameResolver" ), ): Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index a059f7f2..611431ce 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -9,6 +9,7 @@ idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ + import random from test._async_compat import mark_sync_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 832a5367..ff421a4b 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -667,7 +667,7 @@ def test_fetch_relations(): .fetch_relations(Optional("coffees__suppliers")) .all() ) - assert result[0][0] is None + assert len(result) == 0 if Util.is_async_code: count = ( @@ -902,13 +902,12 @@ def test_intermediate_transform(): ) assert len(result) == 1 - assert len(result[0]) == 2 - assert result[0][1] == supplier2 + assert result[0] == supplier2 with raises( ValueError, match=re.escape( - r"Wrong source type specified for variable 'test', should be a string or an instance of RelationNameResolver" + r"Wrong source type specified for variable 'test', should be a string or an instance of NodeNameResolver or RelationNameResolver" ), ): Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( From 212b3499b63867d53782248cc5f55d20ee6fd207 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: 
Tue, 22 Oct 2024 16:20:11 +0200 Subject: [PATCH 26/42] Improved tests --- test/async_/test_match_api.py | 27 +++++++++++++++++++++++++-- test/sync_/test_match_api.py | 27 +++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index a69b9834..81944871 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -25,6 +25,7 @@ AsyncTraversal, Collect, Last, + NodeNameResolver, Optional, RawCypher, RelationNameResolver, @@ -768,6 +769,19 @@ async def test_annotate_and_collect(): ) assert len(result[0][1][0]) == 2 # 2 species must be there + result = ( + await Supplier.nodes.traverse_relations("coffees__species") + .annotate( + all_species=Collect(NodeNameResolver("coffees__species"), distinct=True), + all_species_rels=Collect( + RelationNameResolver("coffees__species"), distinct=True + ), + ) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][2][0]) == 3 # 3 species relations must be there + @mark_async_test async def test_resolve_subgraph(): @@ -900,11 +914,11 @@ async def test_intermediate_transform(): await nescafe.species.connect(arabica) result = ( - await Coffee.nodes.traverse_relations(suppliers="suppliers") + await Coffee.nodes.fetch_relations("suppliers") .intermediate_transform( { "coffee": "coffee", - "suppliers": "suppliers", + "suppliers": NodeNameResolver("suppliers"), "r": RelationNameResolver("suppliers"), }, ordering=["-r.since"], @@ -927,6 +941,15 @@ async def test_intermediate_transform(): "test": Collect("suppliers"), } ) + with raises( + ValueError, + match=re.escape( + r"You must provide one variable at least when calling intermediate_transform()" + ), + ): + Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( + {} + ) @mark_async_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index ff421a4b..e47e3396 100644 --- 
a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -23,6 +23,7 @@ from neomodel.sync_.match import ( Collect, Last, + NodeNameResolver, NodeSet, Optional, QueryBuilder, @@ -756,6 +757,19 @@ def test_annotate_and_collect(): ) assert len(result[0][1][0]) == 2 # 2 species must be there + result = ( + Supplier.nodes.traverse_relations("coffees__species") + .annotate( + all_species=Collect(NodeNameResolver("coffees__species"), distinct=True), + all_species_rels=Collect( + RelationNameResolver("coffees__species"), distinct=True + ), + ) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + assert len(result[0][2][0]) == 3 # 3 species relations must be there + @mark_sync_test def test_resolve_subgraph(): @@ -888,11 +902,11 @@ def test_intermediate_transform(): nescafe.species.connect(arabica) result = ( - Coffee.nodes.traverse_relations(suppliers="suppliers") + Coffee.nodes.fetch_relations("suppliers") .intermediate_transform( { "coffee": "coffee", - "suppliers": "suppliers", + "suppliers": NodeNameResolver("suppliers"), "r": RelationNameResolver("suppliers"), }, ordering=["-r.since"], @@ -915,6 +929,15 @@ def test_intermediate_transform(): "test": Collect("suppliers"), } ) + with raises( + ValueError, + match=re.escape( + r"You must provide one variable at least when calling intermediate_transform()" + ), + ): + Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform( + {} + ) @mark_sync_test From f66c6530b5abd782c890c8293654740beb2bf459 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 25 Oct 2024 14:47:53 +0200 Subject: [PATCH 27/42] Add section for filtering, ordering and traversal --- doc/source/filtering_ordering.rst | 197 ++++++++++++++++++++++++++++++ doc/source/traversal.rst | 88 +++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 doc/source/filtering_ordering.rst create mode 100644 doc/source/traversal.rst diff --git a/doc/source/filtering_ordering.rst 
b/doc/source/filtering_ordering.rst new file mode 100644 index 00000000..3dee395f --- /dev/null +++ b/doc/source/filtering_ordering.rst @@ -0,0 +1,197 @@ +====================== +Filtering and ordering +====================== + +For the examples in this section, we will be using the following model:: + + class SupplierRel(StructuredRel): + since = DateTimeProperty(default=datetime.now) + + + class Supplier(StructuredNode): + name = StringProperty() + delivery_cost = IntegerProperty() + + + class Coffee(StructuredNode): + name = StringProperty(unique_index=True) + price = IntegerProperty() + suppliers = RelationshipFrom(Supplier, 'SUPPLIES', model=SupplierRel) + +Filtering +========= + +neomodel allows filtering on nodes' and relationships' properties. Filters can be combined using Django's Q syntax. It also allows multi-hop relationship traversals to filter on "remote" elements. + +Filter methods +-------------- + +The ``.nodes`` property of a class returns all nodes of that type from the database. 
+ +This set (called `NodeSet`) can be iterated over and filtered on, using the `.filter` method:: + + # nodes with label Coffee whose price is greater than 2 + high_end_coffees = Coffee.nodes.filter(price__gt=2) + + try: + java = Coffee.nodes.get(name='Java') + except DoesNotExist: + # .filter will not throw an exception if no results are found + # but .get will + print("Couldn't find coffee 'Java'") + +The filter method borrows the same Django filter format with double underscore prefixed operators: + +- lt - less than +- gt - greater than +- lte - less than or equal to +- gte - greater than or equal to +- ne - not equal +- in - item in list +- isnull - `True` IS NULL, `False` IS NOT NULL +- exact - string equals +- iexact - string equals, case insensitive +- contains - contains string value +- icontains - contains string value, case insensitive +- startswith - starts with string value +- istartswith - starts with string value, case insensitive +- endswith - ends with string value +- iendswith - ends with string value, case insensitive +- regex - matches a regex expression +- iregex - matches a regex expression, case insensitive + +These operators work with both `.get` and `.filter` methods. + +Combining filters +----------------- + +The filter method allows you to combine multiple filters:: + + cheap_arabicas = Coffee.nodes.filter(price__lt=5, name__icontains='arabica') + +These filters are combined using the logical AND operator. To execute more complex logic (for example, queries with OR statements), `Q objects <neomodel.Q>` can be used. This is borrowed from Django. + +``Q`` objects can be combined using the ``&`` and ``|`` operators. Statements of arbitrary complexity can be composed by combining ``Q`` objects +with the ``&`` and ``|`` operators and use parenthetical grouping.
Also, ``Q`` +objects can be negated using the ``~`` operator, allowing for combined lookups +that combine both a normal query and a negated (``NOT``) query:: + + Q(name__icontains='arabica') | ~Q(name__endswith='blend') + +Chaining ``Q`` objects will join them as an AND clause:: + + not_middle_priced_arabicas = Coffee.nodes.filter( + Q(name__icontains='arabica'), + Q(price__lt=5) | Q(price__gt=10) + ) + +Traversals and filtering +------------------------ + +Sometimes you need to filter nodes based on other nodes they are connected to. This can be done by including a traversal in the `filter` method. :: + + # Find all suppliers of coffee 'Java' who have been supplying since 2007 + # But whose prices are greater than 5 + since_date = datetime(2007, 1, 1) + java_old_timers = Coffee.nodes.filter( + name='Java', + suppliers|since__lt=since_date, + suppliers__delivery_cost__gt=5 + ) + +In the example above, note the following syntax elements: + +- The name of relationships as defined in the `StructuredNode` class is used to traverse relationships. `suppliers` in this example. +- Double underscore `__` is used to target a property of a node. `delivery_cost` in this example. +- A pipe `|` is used to separate the relationship traversal from the property filter. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship. +- The filter operators like lt, gt, etc. can be used on the filtered property. 
+ +Traversals can be of any length, with each relationship separated by a double underscore `__`, for example:: + + # country is here a relationship between Supplier and Country + Coffee.nodes.filter(suppliers__country__name='Brazil') + +Enforcing relationship/path existence +------------------------------------- + +The `has` method checks for existence of (one or more) relationships, in this case it returns a set of `Coffee` nodes which have a supplier:: + + Coffee.nodes.has(suppliers=True) + +This can be negated by setting `suppliers=False`, to find `Coffee` nodes without `suppliers`. + +You can also filter on the existence of more complex traversals by using the `traverse_relations` method. (ADD LINK TO DOC) + +Ordering +======== + +neomodel allows ordering by nodes' and relationships' properties. Order can be ascending or descending. It also allows multi-hop relationship traversals to order on "remote" elements. Finally, you can inject raw Cypher clauses to have full control over ordering when necessary. + +order_by +-------- + +Ordering results by a particular property is done via the `order_by` method:: + + # Ascending sort + for coffee in Coffee.nodes.order_by('price'): + print(coffee, coffee.price) + + # Descending sort + for supplier in Supplier.nodes.order_by('-delivery_cost'): + print(supplier, supplier.delivery_cost) + + +Removing the ordering from a previously defined query, is done by passing `None` to `order_by`:: + + # Sort in descending order + suppliers = Supplier.nodes.order_by('-delivery_cost') + + # Don't order; yield nodes in the order neo4j returns them + suppliers = suppliers.order_by(None) + +For random ordering simply pass '?' to the order_by method:: + + Coffee.nodes.order_by('?') + +Traversals and ordering +----------------------- + +Sometimes you need to order results based on properties situated on different nodes or relationships. This can be done by including a traversal in the `order_by` method.
:: + + # Find the most expensive coffee to deliver + # Then order by the date the supplier started supplying + Coffee.nodes.order_by( + '-suppliers__delivery_cost', + 'suppliers|since', + ) + +In the example above, note the following syntax elements: + +- The name of relationships as defined in the `StructuredNode` class is used to traverse relationships. `suppliers` in this example. +- Double underscore `__` is used to target a property of a node. `delivery_cost` in this example. +- A pipe `|` is used to separate the relationship traversal from the property filter. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship. + +Traversals can be of any length, with each relationships separated by a double underscore `__`, for example:: + + # country is here a relationship between Supplier and Country + Coffee.nodes.order_by('suppliers__country__latitude') + +RawCypher +--------- + +When you need more advanced ordering capabilities, for example to apply order to a transformed property, you can use the `RawCypher` method, like so:: + + from neomodel.sync_.match import RawCypher + + class SoftwareDependency(AsyncStructuredNode): + name = StringProperty() + version = StringProperty() + + SoftwareDependency(name="Package2", version="1.4.0").save() + SoftwareDependency(name="Package3", version="2.5.5").save() + + latest_dep = SoftwareDependency.nodes.order_by( + RawCypher("toInteger(split($n.version, '.')[0]) DESC"), + ) + +In the example above, note the `$n` placeholder in the `RawCypher` clause. This is a placeholder for the node being ordered (`SoftwareDependency` in this case). diff --git a/doc/source/traversal.rst b/doc/source/traversal.rst new file mode 100644 index 00000000..25931d3c --- /dev/null +++ b/doc/source/traversal.rst @@ -0,0 +1,88 @@ +============== +Path traversal +============== + +Neo4j is about traversing the graph, which means leveraging nodes and relations between them. 
This section will show you how to traverse the graph using neomodel. + +For the examples in this section, we will be using the following model:: + + class Country(StructuredNode): + country_code = StringProperty(unique_index=True) + name = StringProperty() + + class Supplier(StructuredNode): + name = StringProperty() + delivery_cost = IntegerProperty() + country = RelationshipTo(Country, 'ESTABLISHED_IN') + + class Coffee(StructuredNode): + name = StringProperty(unique_index=True) + price = IntegerProperty() + suppliers = RelationshipFrom(Supplier, 'SUPPLIES') + +Traverse relations +------------------ + +The `traverse_relations` method allows you to filter on the existence of more complex traversals. For example, to find all `Coffee` nodes that have a supplier, and retrieve the country of that supplier, you can do:: + + Coffee.nodes.traverse_relations(country='suppliers__country').all() + +This will generate a Cypher MATCH clause that enforces the existence of at least one path like `Coffee<--Supplier-->Country`. + +The `Country` nodes matched will be made available for the rest of the query, with the variable name `country`. Note that this aliasing is optional. See the section on Advanced query operations for examples of how to use this aliasing. (ADD LINK TO DOC) + +.. note:: + + The `traverse_relations` method can be used to traverse multiple relationships, like:: + + Coffee.nodes.traverse_relations('suppliers__country', 'pub__city').all() + + This will generate a Cypher MATCH clause that enforces the existence of at least one path like `Coffee<--Supplier-->Country` and `Coffee<--Pub-->City`. + +Fetch relations +--------------- + +The syntax for `fetch_relations` is similar to `traverse_relations`, except that the generated Cypher will return all traversed objects (nodes and relations):: + + Coffee.nodes.fetch_relations(country='suppliers__country').all() + +.. 
note:: + + Any relationship that you intend to traverse using this method **MUST have a model defined**, even if only the default StructuredRel, like:: + + class Person(StructuredNode): + country = RelationshipTo(Country, 'IS_FROM', model=StructuredRel) + + Otherwise, neomodel will not be able to determine which relationship model to resolve into, and will fail. + +Optional match +-------------- + +With both `traverse_relations` and `fetch_relations`, you can force the use of an ``OPTIONAL MATCH`` statement using the following syntax:: + + from neomodel.match import Optional + + # Return the Coffee nodes, and if they have suppliers, return the suppliers as well + results = Coffee.nodes.fetch_relations(Optional('suppliers')).all() + +.. note:: + + You can fetch one or more relations within the same call + to `.fetch_relations()` and you can mix optional and non-optional + relations, like:: + + Person.nodes.fetch_relations('city__country', Optional('country')).all() + +Resolve results +--------------- + +By default, fetch_relations will return a list of tuples. If your path looks like ``(startNode:Coffee)<-[r1]-(middleNode:Supplier)-[r2]->(endNode:Country)``, +then you will get a list of results, where each result is a list of ``(startNode, r1, middleNode, r2, endNode)``. +These will be resolved by neomodel, so ``startNode`` will be a ``Coffee`` class as defined in neomodel for example. + +Using the `resolve_subgraph` method, you can get instead a list of "subgraphs", where each returned `StructuredNode` element will contain its relations and neighbour nodes. For example:: + + results = Coffee.nodes.fetch_relations('suppliers__country').resolve_subgraph().all() + +In this example, `results[0]` will be a `Coffee` object, with a `_relations` attribute. This will in turn have a `suppliers` and a `suppliers_relationship` attribute, which will contain the `Supplier` object and the relation object respectively. 
Recursively, the `Supplier` object will have a `country` attribute, which will contain the `Country` object. + From 3da6102a4256974cacb42a34443c0f8a01458af4 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 25 Oct 2024 16:05:55 +0200 Subject: [PATCH 28/42] Add Advanced query operations ; add interlinks --- doc/source/advanced_query_operations.rst | 103 +++++++++ doc/source/cypher.rst | 12 ++ doc/source/filtering_ordering.rst | 2 +- doc/source/getting_started.rst | 54 ++--- doc/source/index.rst | 4 +- doc/source/queries.rst | 258 ----------------------- doc/source/traversal.rst | 4 +- 7 files changed, 144 insertions(+), 293 deletions(-) create mode 100644 doc/source/advanced_query_operations.rst delete mode 100644 doc/source/queries.rst diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst new file mode 100644 index 00000000..16c93427 --- /dev/null +++ b/doc/source/advanced_query_operations.rst @@ -0,0 +1,103 @@ +.. _Advanced query operations: + +========================= +Advanced query operations +========================= + +neomodel provides ways to enhance your queries beyond filtering and traversals. + +Annotate - Aliasing +------------------- + +The `annotate` method allows you to add transformations to your elements. To learn more about the available transformations, keep reading this section. + +Aggregations +------------ + +neomodel implements some of the aggregation methods available in Cypher: + +- Collect +- Last + +These are usable in this way:: + + from neomodel.sync_match import Collect, Last + + # distinct is optional, and defaults to False. When true, objects are deduplicated + Supplier.nodes.traverse_relations(available_species="coffees__species") + .annotate(Collect("available_species", distinct=True)) + .all() + + # Last is used to get the last element of a list + Supplier.nodes.traverse_relations(available_species="coffees__species") + .annotate(Last(Collect("last_species"))) + .all() + +.. 
note:: + Using the Last() method right after a Collect() without having set an ordering will return the last element in the list as it was returned by the database. + + This is probably not what you want ; which means you must provide an explicit ordering. To do so, you cannot neomodel's order_by method, but need an intermediate transformation step (see below). + + This is because the order_by method adds ordering as the very last step of the Cypher query ; whereas in the present example, you want to first order Species, then get the last one, and then finally return your results. In other words, you need an intermediate WITH Cypher clause. + +Intermediate transformations +---------------------------- + +The `intermediate_transform` method basically allows you to add a WITH clause to your query. This is useful when you need to perform some operations on your results before returning them. + +As discussed in the note above, this is for example useful when you need to order your results before applying an aggregation method, like so:: + + from neomodel.sync_match import Collect, Last + + # This will return all Coffee nodes, with their most expensive supplier + Coffee.nodes.traverse_relations(suppliers="suppliers") + .intermediate_transform( + {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + ) + .annotate(supps=Last(Collect("suppliers"))) + +Subqueries +---------- + +The `subquery` method allows you to perform a `Cypher subquery `_ inside your query. 
This allows you to perform operations in isolation from the rest of your query:: + + from neomodel.sync_.match import Collect, Last + + # This will create a CALL{} subquery + # And return a variable named supps usable in the rest of your query + Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers") + .intermediate_transform( + {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"] + ) + .annotate(supps=Last(Collect("suppliers"))), + ["supps"], + ) + +Helpers +------- + +Reading the sections above, you may have noticed that we used explicit aliasing in the examples, as in:: + + traverse_relations(suppliers="suppliers") + +This allows you to reference the generated Cypher variables in your transformation steps, for example:: + + traverse_relations(suppliers="suppliers").annotate(Collect("suppliers")) + +In some cases though, it is not possible to set explicit aliases, for example when using `fetch_relations`. In these cases, neomodel provides `resolver` methods, so you do not have to guess the name of the variable in the generated Cypher. Those are `NodeNameResolver` and `RelationNameResolver`. For example:: + + from neomodel.sync_.match import Collect, NodeNameResolver, RelationNameResolver + + Supplier.nodes.fetch_relations("coffees__species") + .annotate( + all_species=Collect(NodeNameResolver("coffees__species"), distinct=True), + all_species_rels=Collect( + RelationNameResolver("coffees__species"), distinct=True + ), + ) + .all() + +.. note:: + + When using the resolvers in combination with a traversal as in the example above, it will resolve the variable name of the last element in the traversal - the Species node for NodeNameResolver, and Coffee--Species relationship for RelationNameResolver. 
\ No newline at end of file diff --git a/doc/source/cypher.rst b/doc/source/cypher.rst index f8c7ccaf..37ebcbf1 100644 --- a/doc/source/cypher.rst +++ b/doc/source/cypher.rst @@ -24,6 +24,18 @@ Outside of a `StructuredNode`:: The ``resolve_objects`` parameter automatically inflates the returned nodes to their defined classes (this is turned **off** by default). See :ref:`automatic_class_resolution` for details and possible pitfalls. +You canalso retrieve a whole path of already instantiated objects corresponding to +the nodes and relationship classes with a single query:: + + q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + resolve_objects = True) + +Notice here that ``resolve_objects`` is set to ``True``. This results in ``q`` being a +list of ``result, result_name`` and ``q[0][0][0]`` being a ``NeomodelPath`` object. + +``NeomodelPath`` ``nodes, relationships`` attributes contain already instantiated objects of the +nodes and relationships in the query, *in order of appearance*. + Integrations ============ diff --git a/doc/source/filtering_ordering.rst b/doc/source/filtering_ordering.rst index 3dee395f..3b1873a0 100644 --- a/doc/source/filtering_ordering.rst +++ b/doc/source/filtering_ordering.rst @@ -120,7 +120,7 @@ The `has` method checks for existence of (one or more) relationships, in this ca This can be negated by setting `suppliers=False`, to find `Coffee` nodes without `suppliers`. -You can also filter on the existence of more complex traversals by using the `traverse_relations` method. (ADD LINK TO DOC) +You can also filter on the existence of more complex traversals by using the `traverse_relations` method. See :ref:`Path traversal`. 
Ordering ======== diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 6e8a5aa0..82a0023b 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -193,6 +193,28 @@ simply returning the node IDs rather than every attribute associated with that N # Return set of nodes people = Person.nodes.filter(age__gt=3) +Iteration, slicing and more +--------------------------- + +Iteration, slicing and counting is also supported:: + + # Iterable + for coffee in Coffee.nodes: + print coffee.name + + # Sliceable using python slice syntax + coffee = Coffee.nodes.filter(price__gt=2)[2:] + +The slice syntax returns a NodeSet object which can in turn be chained. + +Length and boolean methods do not return NodeSet objects and cannot be chained further:: + + # Count with __len__ + print len(Coffee.nodes.filter(price__gt=2)) + + if Coffee.nodes: + print "We have coffee nodes!" + Relationships ============= @@ -236,38 +258,6 @@ Working with relationships:: Retrieving additional relations =============================== -To avoid queries multiplication, you have the possibility to retrieve -additional relations with a single call:: - - # The following call will generate one MATCH with traversal per - # item in .fetch_relations() call - results = Person.nodes.fetch_relations('country').all() - for result in results: - print(result[0]) # Person - print(result[1]) # associated Country - -You can traverse more than one hop in your relations using the -following syntax:: - - # Go from person to City then Country - Person.nodes.fetch_relations('city__country').all() - -You can also force the use of an ``OPTIONAL MATCH`` statement using -the following syntax:: - - from neomodel.match import Optional - - results = Person.nodes.fetch_relations(Optional('country')).all() - -.. 
note:: - - Any relationship that you intend to traverse using this method **MUST have a model defined**, even if only the default StructuredRel, like:: - - class Person(StructuredNode): - country = RelationshipTo(Country, 'IS_FROM', model=StructuredRel) - - Otherwise, neomodel will not be able to determine which relationship model to resolve into, and will fail. - .. note:: You can fetch one or more relations within the same call diff --git a/doc/source/index.rst b/doc/source/index.rst index 91a728c0..068e2d93 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -74,7 +74,9 @@ Contents properties spatial_properties schema_management - queries + filtering_ordering + traversal + advanced_query_operations cypher transactions hooks diff --git a/doc/source/queries.rst b/doc/source/queries.rst deleted file mode 100644 index 4c77a791..00000000 --- a/doc/source/queries.rst +++ /dev/null @@ -1,258 +0,0 @@ -================ -Advanced queries -================ - -Neomodel contains an API for querying sets of nodes without having to write cypher:: - - class SupplierRel(StructuredRel): - since = DateTimeProperty(default=datetime.now) - - - class Supplier(StructuredNode): - name = StringProperty() - delivery_cost = IntegerProperty() - coffees = RelationshipTo('Coffee', 'SUPPLIES') - - - class Coffee(StructuredNode): - name = StringProperty(unique_index=True) - price = IntegerProperty() - suppliers = RelationshipFrom(Supplier, 'SUPPLIES', model=SupplierRel) - -Node sets and filtering -======================= - -The ``.nodes`` property of a class returns all nodes of that type from the database. - -This set (or `NodeSet`) can be iterated over and filtered on. 
Under the hood it uses labels introduced in Neo4J 2:: - - # nodes with label Coffee whose price is greater than 2 - Coffee.nodes.filter(price__gt=2) - - try: - java = Coffee.nodes.get(name='Java') - except Coffee.DoesNotExist: - print "Couldn't find coffee 'Java'" - -The filter method borrows the same Django filter format with double underscore prefixed operators: - -- lt - less than -- gt - greater than -- lte - less than or equal to -- gte - greater than or equal to -- ne - not equal -- in - item in list -- isnull - `True` IS NULL, `False` IS NOT NULL -- exact - string equals -- iexact - string equals, case insensitive -- contains - contains string value -- icontains - contains string value, case insensitive -- startswith - starts with string value -- istartswith - starts with string value, case insensitive -- endswith - ends with string value -- iendswith - ends with string value, case insensitive -- regex - matches a regex expression -- iregex - matches a regex expression, case insensitive - -Complex lookups with ``Q`` objects -================================== - -Keyword argument queries -- in `filter`, -etc. -- are "AND"ed together. To execute more complex queries (for -example, queries with ``OR`` statements), `Q objects ` can -be used. - -A `Q object` (``neomodel.Q``) is an object -used to encapsulate a collection of keyword arguments. These keyword arguments -are specified as in "Field lookups" above. - -For example, this ``Q`` object encapsulates a single ``LIKE`` query:: - - from neomodel import Q - Q(name__startswith='Py') - -``Q`` objects can be combined using the ``&`` and ``|`` operators. When an -operator is used on two ``Q`` objects, it yields a new ``Q`` object. 
- -For example, this statement yields a single ``Q`` object that represents the -"OR" of two ``"name__startswith"`` queries:: - - Q(name__startswith='Py') | Q(name__startswith='Jav') - -This is equivalent to the following SQL ``WHERE`` clause:: - - WHERE name STARTS WITH 'Py' OR name STARTS WITH 'Jav' - -Statements of arbitrary complexity can be composed by combining ``Q`` objects -with the ``&`` and ``|`` operators and use parenthetical grouping. Also, ``Q`` -objects can be negated using the ``~`` operator, allowing for combined lookups -that combine both a normal query and a negated (``NOT``) query:: - - Q(name__startswith='Py') | ~Q(year=2005) - -Each lookup function that takes keyword-arguments -(e.g. `filter`, `exclude`, `get`) can also be passed one or more -``Q`` objects as positional (not-named) arguments. If multiple -``Q`` object arguments are provided to a lookup function, the arguments will be "AND"ed -together. For example:: - - Lang.nodes.filter( - Q(name__startswith='Py'), - Q(year=2005) | Q(year=2006) - ) - -This roughly translates to the following Cypher query:: - - MATCH (lang:Lang) WHERE name STARTS WITH 'Py' - AND (year = 2005 OR year = 2006) - return lang; - -Lookup functions can mix the use of ``Q`` objects and keyword arguments. All -arguments provided to a lookup function (be they keyword arguments or ``Q`` -objects) are "AND"ed together. However, if a ``Q`` object is provided, it must -precede the definition of any keyword arguments. For example:: - - Lang.nodes.get( - Q(year=2005) | Q(year=2006), - name__startswith='Py', - ) - -This would be a valid query, equivalent to the previous example; - -Has a relationship -================== - -The `has` method checks for existence of (one or more) relationships, in this case it returns a set of `Coffee` nodes which have a supplier:: - - Coffee.nodes.has(suppliers=True) - -This can be negated by setting `suppliers=False`, to find `Coffee` nodes without `suppliers`. 
- -Iteration, slicing and more -=========================== - -Iteration, slicing and counting is also supported:: - - # Iterable - for coffee in Coffee.nodes: - print coffee.name - - # Sliceable using python slice syntax - coffee = Coffee.nodes.filter(price__gt=2)[2:] - -The slice syntax returns a NodeSet object which can in turn be chained. - -Length and boolean methods dont return NodeSet objects and cannot be chained further:: - - # Count with __len__ - print len(Coffee.nodes.filter(price__gt=2)) - - if Coffee.nodes: - print "We have coffee nodes!" - -Filtering by relationship properties -==================================== - -Filtering on relationship properties is also possible using the `match` method. Note that again these relationships must have a definition.:: - - coffee_brand = Coffee.nodes.get(name="BestCoffeeEver") - - for supplier in coffee_brand.suppliers.match(since_lt=january): - print(supplier.name) - -Ordering by property -==================== - -Ordering results by a particular property is done via th `order_by` method:: - - # Ascending sort - for coffee in Coffee.nodes.order_by('price'): - print(coffee, coffee.price) - - # Descending sort - for supplier in Supplier.nodes.order_by('-delivery_cost'): - print(supplier, supplier.delivery_cost) - - -Removing the ordering from a previously defined query, is done by passing `None` to `order_by`:: - - # Sort in descending order - suppliers = Supplier.nodes.order_by('-delivery_cost') - - # Don't order; yield nodes in the order neo4j returns them - suppliers = suppliers.order_by(None) - -For random ordering simply pass '?' to the order_by method:: - - Coffee.nodes.order_by('?') - -Retrieving paths -================ - -You can retrieve a whole path of already instantiated objects corresponding to -the nodes and relationship classes with a single query. 
- -Suppose the following schema: - -:: - - class PersonLivesInCity(StructuredRel): - some_num = IntegerProperty(index=True, - default=12) - - class CountryOfOrigin(StructuredNode): - code = StringProperty(unique_index=True, - required=True) - - class CityOfResidence(StructuredNode): - name = StringProperty(required=True) - country = RelationshipTo(CountryOfOrigin, - 'FROM_COUNTRY') - - class PersonOfInterest(StructuredNode): - uid = UniqueIdProperty() - name = StringProperty(unique_index=True) - age = IntegerProperty(index=True, - default=0) - - country = RelationshipTo(CountryOfOrigin, - 'IS_FROM') - city = RelationshipTo(CityOfResidence, - 'LIVES_IN', - model=PersonLivesInCity) - -Then, paths can be retrieved with: - -:: - - q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", - resolve_objects = True) - -Notice here that ``resolve_objects`` is set to ``True``. This results in ``q`` being a -list of ``result, result_name`` and ``q[0][0][0]`` being a ``NeomodelPath`` object. - -``NeomodelPath`` ``nodes, relationships`` attributes contain already instantiated objects of the -nodes and relationships in the query, *in order of appearance*. - -It would be particularly useful to note here that each object is read exactly once from -the database. Therefore, nodes will be instantiated to their neomodel node objects and -relationships to their relationship models *if such a model exists*. In other words, -relationships with data (such as ``PersonLivesInCity`` above) will be instantiated to their -respective objects or ``StrucuredRel`` otherwise. Relationships do not "reload" their -end-points (unless this is required). - -Async neomodel - Caveats -======================== - -Python does not support async dunder methods. This means that we had to implement some overrides for those. 
-See the example below:: - - # This will not work as it uses the synchronous __bool__ method - assert await Customer.nodes.filter(prop="value") - - # Do this instead - assert await Customer.nodes.filter(prop="value").check_bool() - assert await Customer.nodes.filter(prop="value").check_nonzero() - - # Note : no changes are needed for sync so this still works : - assert Customer.nodes.filter(prop="value") diff --git a/doc/source/traversal.rst b/doc/source/traversal.rst index 25931d3c..4cbb2fd4 100644 --- a/doc/source/traversal.rst +++ b/doc/source/traversal.rst @@ -1,3 +1,5 @@ +.. _Path traversal: + ============== Path traversal ============== @@ -29,7 +31,7 @@ The `traverse_relations` method allows you to filter on the existence of more co This will generate a Cypher MATCH clause that enforces the existence of at least one path like `Coffee<--Supplier-->Country`. -The `Country` nodes matched will be made available for the rest of the query, with the variable name `country`. Note that this aliasing is optional. See the section on Advanced query operations for examples of how to use this aliasing. (ADD LINK TO DOC) +The `Country` nodes matched will be made available for the rest of the query, with the variable name `country`. Note that this aliasing is optional. See :ref:`Advanced query operations` for examples of how to use this aliasing. .. 
note:: From 974d90c4c63443f490e51ae885729d39ce79a630 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 28 Oct 2024 09:29:51 +0100 Subject: [PATCH 29/42] Fix based on comments --- doc/source/advanced_query_operations.rst | 6 +++--- doc/source/cypher.rst | 2 +- doc/source/filtering_ordering.rst | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index 16c93427..8cef8a1e 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -21,7 +21,7 @@ neomodel implements some of the aggregation methods available in Cypher: These are usable in this way:: - from neomodel.sync_match import Collect, Last + from neomodel.sync_.match import Collect, Last # distinct is optional, and defaults to False. When true, objects are deduplicated Supplier.nodes.traverse_relations(available_species="coffees__species") @@ -36,7 +36,7 @@ These are usable in this way:: .. note:: Using the Last() method right after a Collect() without having set an ordering will return the last element in the list as it was returned by the database. - This is probably not what you want ; which means you must provide an explicit ordering. To do so, you cannot neomodel's order_by method, but need an intermediate transformation step (see below). + This is probably not what you want ; which means you must provide an explicit ordering. To do so, you cannot use neomodel's `order_by` method, but need an intermediate transformation step (see below). This is because the order_by method adds ordering as the very last step of the Cypher query ; whereas in the present example, you want to first order Species, then get the last one, and then finally return your results. In other words, you need an intermediate WITH Cypher clause. 
@@ -47,7 +47,7 @@ The `intermediate_transform` method basically allows you to add a WITH clause to As discussed in the note above, this is for example useful when you need to order your results before applying an aggregation method, like so:: - from neomodel.sync_match import Collect, Last + from neomodel.sync_.match import Collect, Last # This will return all Coffee nodes, with their most expensive supplier Coffee.nodes.traverse_relations(suppliers="suppliers") diff --git a/doc/source/cypher.rst b/doc/source/cypher.rst index 37ebcbf1..8ce2a42e 100644 --- a/doc/source/cypher.rst +++ b/doc/source/cypher.rst @@ -24,7 +24,7 @@ Outside of a `StructuredNode`:: The ``resolve_objects`` parameter automatically inflates the returned nodes to their defined classes (this is turned **off** by default). See :ref:`automatic_class_resolution` for details and possible pitfalls. -You canalso retrieve a whole path of already instantiated objects corresponding to +You can also retrieve a whole path of already instantiated objects corresponding to the nodes and relationship classes with a single query:: q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", diff --git a/doc/source/filtering_ordering.rst b/doc/source/filtering_ordering.rst index 3b1873a0..795ccef9 100644 --- a/doc/source/filtering_ordering.rst +++ b/doc/source/filtering_ordering.rst @@ -95,15 +95,15 @@ Sometimes you need to filter nodes based on other nodes they are connected to. T since_date = datetime(2007, 1, 1) java_old_timers = Coffee.nodes.filter( name='Java', - suppliers|since__lt=since_date, - suppliers__delivery_cost__gt=5 + suppliers__delivery_cost__gt=5, + **{"suppliers|since__lt": since_date} ) In the example above, note the following syntax elements: - The name of relationships as defined in the `StructuredNode` class is used to traverse relationships. `suppliers` in this example. 
- Double underscore `__` is used to target a property of a node. `delivery_cost` in this example. -- A pipe `|` is used to separate the relationship traversal from the property filter. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship. +- A pipe `|` is used to separate the relationship traversal from the property filter. The filter also has to be included in a `**kwargs` dictionary, because the pipe character would break the syntax. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship. - The filter operators like lt, gt, etc. can be used on the filtered property. Traversals can be of any length, with each relationships separated by a double underscore `__`, for example:: @@ -130,7 +130,7 @@ neomodel allows ordering by nodes' and relationships' properties. Order can be a order_by -------- -Ordering results by a particular property is done via th `order_by` method:: +Ordering results by a particular property is done via the `order_by` method:: # Ascending sort for coffee in Coffee.nodes.order_by('price'): From cf3d18bbdb488c06e9809c907ffb51d66261e046 Mon Sep 17 00:00:00 2001 From: Daniyar Irishev Date: Fri, 1 Nov 2024 09:50:11 +0900 Subject: [PATCH 30/42] add tests --- test/async_/test_properties.py | 11 +++++++++++ test/sync_/test_properties.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 4f3eab2d..6949372d 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -310,6 +310,17 @@ def test_json(): assert prop.inflate('{"test": [1, 2, 3]}') == value +def test_json_unicode(): + prop = JSONProperty(ensure_ascii=False) + prop.name = "json" + prop.owner = FooBar + + value = {"test": [1, 2, 3, "©"]} + + assert prop.deflate(value) == '{"test": [1, 2, 3, "©"]}' + assert prop.inflate('{"test": [1, 2, 3, 
©]}') == value + + def test_indexed(): indexed = StringProperty(index=True) assert indexed.is_indexed is True diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 1afe52a2..b61182fb 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -304,6 +304,17 @@ def test_json(): assert prop.inflate('{"test": [1, 2, 3]}') == value +def test_json_unicode(): + prop = JSONProperty(ensure_ascii=False) + prop.name = "json" + prop.owner = FooBar + + value = {"test": [1, 2, 3, "©"]} + + assert prop.deflate(value) == '{"test": [1, 2, 3, "©"]}' + assert prop.inflate('{"test": [1, 2, 3, ©]}') == value + + def test_indexed(): indexed = StringProperty(index=True) assert indexed.is_indexed is True From 7a36673ac7fedbc6409c1186453469e872f0c441 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 4 Nov 2024 10:12:03 +0100 Subject: [PATCH 31/42] Add full example in test - Temp, code non functional --- doc/source/advanced_query_operations.rst | 4 +- doc/source/traversal.rst | 6 +- test/async_/test_match_api.py | 140 +++++++++++++++++++++-- test/sync_/test_match_api.py | 138 ++++++++++++++++++++-- 4 files changed, 272 insertions(+), 16 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index 8cef8a1e..6a2b5466 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -16,7 +16,7 @@ Aggregations neomodel implements some of the aggregation methods available in Cypher: -- Collect +- Collect (with distinct option) - Last These are usable in this way:: @@ -33,6 +33,8 @@ These are usable in this way:: .annotate(Last(Collect("last_species"))) .all() +Note how `annotate` is used to add the aggregation method to the query. + .. note:: Using the Last() method right after a Collect() without having set an ordering will return the last element in the list as it was returned by the database. 
diff --git a/doc/source/traversal.rst b/doc/source/traversal.rst index 4cbb2fd4..e4d94b34 100644 --- a/doc/source/traversal.rst +++ b/doc/source/traversal.rst @@ -78,7 +78,7 @@ With both `traverse_relations` and `fetch_relations`, you can force the use of a Resolve results --------------- -By default, fetch_relations will return a list of tuples. If your path looks like ``(startNode:Coffee)<-[r1]-(middleNode:Supplier)-[r2]->(endNode:Country)``, +By default, `fetch_relations` will return a list of tuples. If your path looks like ``(startNode:Coffee)<-[r1]-(middleNode:Supplier)-[r2]->(endNode:Country)``, then you will get a list of results, where each result is a list of ``(startNode, r1, middleNode, r2, endNode)``. These will be resolved by neomodel, so ``startNode`` will be a ``Coffee`` class as defined in neomodel for example. @@ -88,3 +88,7 @@ Using the `resolve_subgraph` method, you can get instead a list of "subgraphs", In this example, `results[0]` will be a `Coffee` object, with a `_relations` attribute. This will in turn have a `suppliers` and a `suppliers_relationship` attribute, which will contain the `Supplier` object and the relation object respectively. Recursively, the `Supplier` object will have a `country` attribute, which will contain the `Country` object. +.. note:: + + The `resolve_subgraph` method is only available for `fetch_relations` queries. This is because `traverse_relations` queries do not return any relations, and thus there is no need to resolve them. 
+ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 81944871..ab4bb9cf 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -91,6 +91,30 @@ class SoftwareDependency(AsyncStructuredNode): version = StringProperty(required=True) +class HasCourseRel(AsyncStructuredRel): + level = StringProperty() + start_date = DateTimeProperty() + end_date = DateTimeProperty() + + +class Course(AsyncStructuredNode): + name = StringProperty() + + +class Building(AsyncStructuredNode): + name = StringProperty() + + +class Student(AsyncStructuredNode): + name = StringProperty() + + parents = AsyncRelationshipTo("Student", "HAS_PARENT") + children = AsyncRelationshipFrom("Student", "HAS_PARENT") + lives_in = AsyncRelationshipTo(Building, "LIVES_IN") + has_course = AsyncRelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) + has_latest_course = AsyncRelationshipTo(Course, "HAS_COURSE") + + @mark_async_test async def test_filter_exclude_via_labels(): await Coffee(name="Java", price=99).save() @@ -557,7 +581,7 @@ async def test_traversal_filter_left_hand_statement(): nescafe = await Coffee(name="Nescafe2", price=99).save() nescafe_gold = await Coffee(name="Nescafe gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() biedronka = await Supplier(name="Biedronka", delivery_cost=5).save() lidl = await Supplier(name="Lidl", delivery_cost=3).save() @@ -583,7 +607,7 @@ async def test_filter_with_traversal(): robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=11).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=99).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -594,6 
+618,15 @@ async def test_filter_with_traversal(): assert len(results[0]) == 3 assert results[0][0] == nescafe + results_multi_hop = await Supplier.nodes.filter( + coffees__species__name="Arabica" + ).all() + assert len(results_multi_hop) == 1 + assert results_multi_hop[0][0] == tesco + + no_results = await Supplier.nodes.filter(coffees__species__name="Noffee").all() + assert no_results == [] + @mark_async_test async def test_relation_prop_filtering(): @@ -616,6 +649,16 @@ async def test_relation_prop_filtering(): assert len(results) == 1 assert results[0][0] == supplier1 + # Test it works with mixed argument syntaxes + results2 = await Supplier.nodes.filter( + name="Supplier 1", + coffees__name="Nescafe", + **{"coffees|since__gt": datetime(2018, 4, 1, 0, 0)}, + ).all() + + assert len(results2) == 1 + assert results2[0][0] == supplier1 + @mark_async_test async def test_relation_prop_ordering(): @@ -656,7 +699,7 @@ async def test_fetch_relations(): nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -715,7 +758,7 @@ async def test_traverse_and_order_by(): robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=110).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -740,7 +783,7 @@ async def test_annotate_and_collect(): nescafe = await Coffee(name="Nescafe 1002", price=99).save() nescafe_gold = await Coffee(name="Nescafe 1003", 
price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -793,7 +836,7 @@ async def test_resolve_subgraph(): nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -842,7 +885,7 @@ async def test_resolve_subgraph_optional(): nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -952,6 +995,89 @@ async def test_intermediate_transform(): ) +@mark_async_test +async def test_mix_functions(): + # Test with a mix of all advanced querying functions + + eiffel_tower = await Building(name="Eiffel Tower").save() + empire_state_building = await Building(name="Empire State Building").save() + miranda = await Student(name="Miranda").save() + await miranda.lives_in.connect(empire_state_building) + jean_pierre = await Student(name="Jean-Pierre").save() + await jean_pierre.lives_in.connect(eiffel_tower) + mireille = await Student(name="Mireille").save() + mimoun_jr = await Student(name="Mimoun Jr").save() + mimoun = await Student(name="Mimoun").save() + await mireille.lives_in.connect(eiffel_tower) + await mimoun_jr.lives_in.connect(eiffel_tower) + await mimoun.lives_in.connect(eiffel_tower) + await mimoun.parents.connect(mireille) + 
await mimoun.children.connect(mimoun_jr) + course = await Course(name="Math").save() + await mimoun.has_course.connect( + course, + { + "level": "1.2", + "start_date": datetime(2020, 6, 2), + "end_date": datetime(2020, 12, 31), + }, + ) + await mimoun.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + await mimoun_jr.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + + filtered_nodeset = Student.nodes.filter( + name__istartswith="m", lives_in__name="Eiffel Tower" + ) + full_nodeset = ( + await filtered_nodeset.order_by("name") + .traverse_relations( + "parents", + ) + .fetch_relations( + "lives_in", + Optional("children__has_latest_course"), + ) + .subquery( + filtered_nodeset.order_by("name") + .fetch_relations("has_course") + .intermediate_transform( + {"rel": RelationNameResolver("has_course")}, + ordering=[ + RawCypher("toInteger(split(rel.level, '.')[0])"), + RawCypher("toInteger(split(rel.level, '.')[1])"), + "rel.end_date", + "rel.start_date", + ], + ) + .annotate( + latest_course=Last(Collect("rel")), + ), + ["latest_course"], + ) + ) + + subgraph = await full_nodeset.annotate( + Collect(NodeNameResolver("children"), distinct=True), + Collect(NodeNameResolver("children__has_latest_course"), distinct=True), + ).resolve_subgraph() + + print(subgraph) + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index e47e3396..139ad3b4 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -89,6 +89,30 @@ class SoftwareDependency(StructuredNode): version = StringProperty(required=True) +class HasCourseRel(StructuredRel): + level = StringProperty() + start_date = DateTimeProperty() + end_date = DateTimeProperty() + + +class Course(StructuredNode): + name 
= StringProperty() + + +class Building(StructuredNode): + name = StringProperty() + + +class Student(StructuredNode): + name = StringProperty() + + parents = RelationshipTo("Student", "HAS_PARENT") + children = RelationshipFrom("Student", "HAS_PARENT") + lives_in = RelationshipTo(Building, "LIVES_IN") + has_course = RelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) + has_latest_course = RelationshipTo(Course, "HAS_COURSE") + + @mark_sync_test def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() @@ -553,7 +577,7 @@ def test_traversal_filter_left_hand_statement(): nescafe = Coffee(name="Nescafe2", price=99).save() nescafe_gold = Coffee(name="Nescafe gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() biedronka = Supplier(name="Biedronka", delivery_cost=5).save() lidl = Supplier(name="Lidl", delivery_cost=3).save() @@ -577,7 +601,7 @@ def test_filter_with_traversal(): robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=11).save() nescafe_gold = Coffee(name="Nescafe Gold", price=99).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -588,6 +612,13 @@ def test_filter_with_traversal(): assert len(results[0]) == 3 assert results[0][0] == nescafe + results_multi_hop = Supplier.nodes.filter(coffees__species__name="Arabica").all() + assert len(results_multi_hop) == 1 + assert results_multi_hop[0][0] == tesco + + no_results = Supplier.nodes.filter(coffees__species__name="Noffee").all() + assert no_results == [] + @mark_sync_test def test_relation_prop_filtering(): @@ -610,6 +641,16 @@ def test_relation_prop_filtering(): assert len(results) == 1 assert results[0][0] == supplier1 + # Test it works with mixed argument syntaxes + results2 = 
Supplier.nodes.filter( + name="Supplier 1", + coffees__name="Nescafe", + **{"coffees|since__gt": datetime(2018, 4, 1, 0, 0)}, + ).all() + + assert len(results2) == 1 + assert results2[0][0] == supplier1 + @mark_sync_test def test_relation_prop_ordering(): @@ -646,7 +687,7 @@ def test_fetch_relations(): nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -705,7 +746,7 @@ def test_traverse_and_order_by(): robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=110).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -728,7 +769,7 @@ def test_annotate_and_collect(): nescafe = Coffee(name="Nescafe 1002", price=99).save() nescafe_gold = Coffee(name="Nescafe 1003", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -781,7 +822,7 @@ def test_resolve_subgraph(): nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -830,7 +871,7 @@ def test_resolve_subgraph_optional(): nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", 
price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -940,6 +981,89 @@ def test_intermediate_transform(): ) +@mark_sync_test +def test_mix_functions(): + # Test with a mix of all advanced querying functions + + eiffel_tower = Building(name="Eiffel Tower").save() + empire_state_building = Building(name="Empire State Building").save() + miranda = Student(name="Miranda").save() + miranda.lives_in.connect(empire_state_building) + jean_pierre = Student(name="Jean-Pierre").save() + jean_pierre.lives_in.connect(eiffel_tower) + mireille = Student(name="Mireille").save() + mimoun_jr = Student(name="Mimoun Jr").save() + mimoun = Student(name="Mimoun").save() + mireille.lives_in.connect(eiffel_tower) + mimoun_jr.lives_in.connect(eiffel_tower) + mimoun.lives_in.connect(eiffel_tower) + mimoun.parents.connect(mireille) + mimoun.children.connect(mimoun_jr) + course = Course(name="Math").save() + mimoun.has_course.connect( + course, + { + "level": "1.2", + "start_date": datetime(2020, 6, 2), + "end_date": datetime(2020, 12, 31), + }, + ) + mimoun.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + mimoun_jr.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + + filtered_nodeset = Student.nodes.filter( + name__istartswith="m", lives_in__name="Eiffel Tower" + ) + full_nodeset = ( + filtered_nodeset.order_by("name") + .traverse_relations( + "parents", + ) + .fetch_relations( + "lives_in", + Optional("children__has_latest_course"), + ) + .subquery( + filtered_nodeset.order_by("name") + .fetch_relations("has_course") + .intermediate_transform( + {"rel": RelationNameResolver("has_course")}, + ordering=[ + 
RawCypher("toInteger(split(rel.level, '.')[0])"), + RawCypher("toInteger(split(rel.level, '.')[1])"), + "rel.end_date", + "rel.start_date", + ], + ) + .annotate( + latest_course=Last(Collect("rel")), + ), + ["latest_course"], + ) + ) + + subgraph = full_nodeset.annotate( + Collect(NodeNameResolver("children"), distinct=True), + Collect(NodeNameResolver("children__has_latest_course"), distinct=True), + ).resolve_subgraph() + + print(subgraph) + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From 1d927bc718c474868a68e9361a0328794e9a46ac Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 4 Nov 2024 10:44:07 +0100 Subject: [PATCH 32/42] Fixed double query execution --- neomodel/async_/match.py | 3 +-- neomodel/sync_/match.py | 3 +-- test/async_/test_match_api.py | 4 +--- test/sync_/test_match_api.py | 4 +--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index b34938c6..8f103c4f 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1499,10 +1499,9 @@ async def resolve_subgraph(self) -> list: raise RuntimeError( "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." ) - all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} root_node = None - async for row in all_nodes: + async for row in qbuilder._execute(dict_output=True): for name, node in row.items(): if node.__class__ is self.source and "_" not in name: root_node = node diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 73715cc8..1bdfe660 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1499,10 +1499,9 @@ def resolve_subgraph(self) -> list: raise RuntimeError( "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." 
) - all_nodes = qbuilder._execute(dict_output=True) other_nodes = {} root_node = None - for row in all_nodes: + for row in qbuilder._execute(dict_output=True): for name, node in row.items(): if node.__class__ is self.source and "_" not in name: root_node = node diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index ab4bb9cf..bed4f877 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1048,12 +1048,10 @@ async def test_mix_functions(): "parents", ) .fetch_relations( - "lives_in", Optional("children__has_latest_course"), ) .subquery( - filtered_nodeset.order_by("name") - .fetch_relations("has_course") + Student.nodes.fetch_relations("has_course") .intermediate_transform( {"rel": RelationNameResolver("has_course")}, ordering=[ diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 139ad3b4..0ba00b38 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1034,12 +1034,10 @@ def test_mix_functions(): "parents", ) .fetch_relations( - "lives_in", Optional("children__has_latest_course"), ) .subquery( - filtered_nodeset.order_by("name") - .fetch_relations("has_course") + Student.nodes.fetch_relations("has_course") .intermediate_transform( {"rel": RelationNameResolver("has_course")}, ordering=[ From 1b26793ac63629301684df020d723ab22d2c65bf Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 4 Nov 2024 12:26:19 +0100 Subject: [PATCH 33/42] Add full example ; fix example in test --- doc/source/advanced_query_operations.rst | 8 ++- doc/source/filtering_ordering.rst | 2 + doc/source/getting_started.rst | 90 ++++++++++++++++++++++++ doc/source/traversal.rst | 2 + test/async_/test_match_api.py | 64 ++++++++++------- test/sync_/test_match_api.py | 64 ++++++++++------- 6 files changed, 179 insertions(+), 51 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index 6a2b5466..e602d479 100644 --- 
a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -67,7 +67,8 @@ The `subquery` method allows you to perform a `Cypher subquery (student_parents:Student) + MATCH (student)-[r4:`LIVES_IN`]->(building_lives_in:Building) + OPTIONAL MATCH (student)<-[r2:`HAS_PARENT`]-(student_children:Student)-[r3:`HAS_PREFERRED_COURSE`]->(course_children__preferred_course:Course) + WITH * + # building_lives_in_name_1 = "Eiffel Tower" + # student_name_1 = "(?i)m.*" + WHERE building_lives_in.name = $building_lives_in_name_1 AND student.name =~ $student_name_1 + CALL { + WITH student + MATCH (student)-[r1:`HAS_COURSE`]->(course_courses:Course) + WITH r1 AS rel + ORDER BY toInteger(split(rel.level, '.')[0]),toInteger(split(rel.level, '.')[1]),rel.end_date,rel.start_date + RETURN last(collect(rel)) AS latest_course + } + RETURN latest_course, student, student_parents, r1, student_children, r2, course_children__preferred_course, r3, building_lives_in, r4, collect(DISTINCT student_children) AS children, collect(DISTINCT course_children__preferred_course) AS children_preferred_course + ORDER BY student.name + """ \ No newline at end of file diff --git a/doc/source/traversal.rst b/doc/source/traversal.rst index e4d94b34..c6347ad0 100644 --- a/doc/source/traversal.rst +++ b/doc/source/traversal.rst @@ -6,6 +6,8 @@ Path traversal Neo4j is about traversing the graph, which means leveraging nodes and relations between them. This section will show you how to traverse the graph using neomodel. +We will cover two methods : `traverse_relations` and `fetch_relations`. Those two methods are *mutually exclusive*, so you cannot chain them. 
+ For the examples in this section, we will be using the following model:: class Country(StructuredNode): diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index bed4f877..56d5f37f 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -11,6 +11,7 @@ AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, + AsyncZeroOrOne, DateTimeProperty, IntegerProperty, Q, @@ -108,11 +109,16 @@ class Building(AsyncStructuredNode): class Student(AsyncStructuredNode): name = StringProperty() - parents = AsyncRelationshipTo("Student", "HAS_PARENT") - children = AsyncRelationshipFrom("Student", "HAS_PARENT") - lives_in = AsyncRelationshipTo(Building, "LIVES_IN") - has_course = AsyncRelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) - has_latest_course = AsyncRelationshipTo(Course, "HAS_COURSE") + parents = AsyncRelationshipTo("Student", "HAS_PARENT", model=AsyncStructuredRel) + children = AsyncRelationshipFrom("Student", "HAS_PARENT", model=AsyncStructuredRel) + lives_in = AsyncRelationshipTo(Building, "LIVES_IN", model=AsyncStructuredRel) + courses = AsyncRelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) + preferred_course = AsyncRelationshipTo( + Course, + "HAS_PREFERRED_COURSE", + model=AsyncStructuredRel, + cardinality=AsyncZeroOrOne, + ) @mark_async_test @@ -1013,25 +1019,26 @@ async def test_mix_functions(): await mimoun.lives_in.connect(eiffel_tower) await mimoun.parents.connect(mireille) await mimoun.children.connect(mimoun_jr) - course = await Course(name="Math").save() - await mimoun.has_course.connect( - course, + math = await Course(name="Math").save() + dessin = await Course(name="Dessin").save() + await mimoun.courses.connect( + math, { "level": "1.2", "start_date": datetime(2020, 6, 2), "end_date": datetime(2020, 12, 31), }, ) - await mimoun.has_course.connect( - course, + await mimoun.courses.connect( + math, { "level": "1.1", "start_date": datetime(2020, 1, 1), "end_date": datetime(2020, 6, 1), 
}, ) - await mimoun_jr.has_course.connect( - course, + await mimoun_jr.courses.connect( + math, { "level": "1.1", "start_date": datetime(2020, 1, 1), @@ -1039,21 +1046,19 @@ async def test_mix_functions(): }, ) - filtered_nodeset = Student.nodes.filter( - name__istartswith="m", lives_in__name="Eiffel Tower" - ) + await mimoun_jr.preferred_course.connect(dessin) + full_nodeset = ( - await filtered_nodeset.order_by("name") - .traverse_relations( - "parents", - ) + await Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower") + .order_by("name") .fetch_relations( - Optional("children__has_latest_course"), + "parents", + Optional("children__preferred_course"), ) .subquery( - Student.nodes.fetch_relations("has_course") + Student.nodes.fetch_relations("courses") .intermediate_transform( - {"rel": RelationNameResolver("has_course")}, + {"rel": RelationNameResolver("courses")}, ordering=[ RawCypher("toInteger(split(rel.level, '.')[0])"), RawCypher("toInteger(split(rel.level, '.')[1])"), @@ -1069,11 +1074,20 @@ async def test_mix_functions(): ) subgraph = await full_nodeset.annotate( - Collect(NodeNameResolver("children"), distinct=True), - Collect(NodeNameResolver("children__has_latest_course"), distinct=True), + children=Collect(NodeNameResolver("children"), distinct=True), + children_preferred_course=Collect( + NodeNameResolver("children__preferred_course"), distinct=True + ), ).resolve_subgraph() - print(subgraph) + assert len(subgraph) == 2 + assert subgraph[0] == mimoun + assert subgraph[1] == mimoun_jr + mimoun_returned_rels = subgraph[0]._relations + assert mimoun_returned_rels["children"] == mimoun_jr + assert mimoun_returned_rels["children"]._relations["preferred_course"] == dessin + assert mimoun_returned_rels["parents"] == mireille + assert mimoun_returned_rels["latest_course_relationship"].level == "1.2" @mark_async_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 0ba00b38..60b48554 100644 --- 
a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -16,6 +16,7 @@ StructuredNode, StructuredRel, UniqueIdProperty, + ZeroOrOne, db, ) from neomodel._async_compat.util import Util @@ -106,11 +107,16 @@ class Building(StructuredNode): class Student(StructuredNode): name = StringProperty() - parents = RelationshipTo("Student", "HAS_PARENT") - children = RelationshipFrom("Student", "HAS_PARENT") - lives_in = RelationshipTo(Building, "LIVES_IN") - has_course = RelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) - has_latest_course = RelationshipTo(Course, "HAS_COURSE") + parents = RelationshipTo("Student", "HAS_PARENT", model=StructuredRel) + children = RelationshipFrom("Student", "HAS_PARENT", model=StructuredRel) + lives_in = RelationshipTo(Building, "LIVES_IN", model=StructuredRel) + courses = RelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) + preferred_course = RelationshipTo( + Course, + "HAS_PREFERRED_COURSE", + model=StructuredRel, + cardinality=ZeroOrOne, + ) @mark_sync_test @@ -999,25 +1005,26 @@ def test_mix_functions(): mimoun.lives_in.connect(eiffel_tower) mimoun.parents.connect(mireille) mimoun.children.connect(mimoun_jr) - course = Course(name="Math").save() - mimoun.has_course.connect( - course, + math = Course(name="Math").save() + dessin = Course(name="Dessin").save() + mimoun.courses.connect( + math, { "level": "1.2", "start_date": datetime(2020, 6, 2), "end_date": datetime(2020, 12, 31), }, ) - mimoun.has_course.connect( - course, + mimoun.courses.connect( + math, { "level": "1.1", "start_date": datetime(2020, 1, 1), "end_date": datetime(2020, 6, 1), }, ) - mimoun_jr.has_course.connect( - course, + mimoun_jr.courses.connect( + math, { "level": "1.1", "start_date": datetime(2020, 1, 1), @@ -1025,21 +1032,19 @@ def test_mix_functions(): }, ) - filtered_nodeset = Student.nodes.filter( - name__istartswith="m", lives_in__name="Eiffel Tower" - ) + mimoun_jr.preferred_course.connect(dessin) + full_nodeset = ( - 
filtered_nodeset.order_by("name") - .traverse_relations( - "parents", - ) + Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower") + .order_by("name") .fetch_relations( - Optional("children__has_latest_course"), + "parents", + Optional("children__preferred_course"), ) .subquery( - Student.nodes.fetch_relations("has_course") + Student.nodes.fetch_relations("courses") .intermediate_transform( - {"rel": RelationNameResolver("has_course")}, + {"rel": RelationNameResolver("courses")}, ordering=[ RawCypher("toInteger(split(rel.level, '.')[0])"), RawCypher("toInteger(split(rel.level, '.')[1])"), @@ -1055,11 +1060,20 @@ def test_mix_functions(): ) subgraph = full_nodeset.annotate( - Collect(NodeNameResolver("children"), distinct=True), - Collect(NodeNameResolver("children__has_latest_course"), distinct=True), + children=Collect(NodeNameResolver("children"), distinct=True), + children_preferred_course=Collect( + NodeNameResolver("children__preferred_course"), distinct=True + ), ).resolve_subgraph() - print(subgraph) + assert len(subgraph) == 2 + assert subgraph[0] == mimoun + assert subgraph[1] == mimoun_jr + mimoun_returned_rels = subgraph[0]._relations + assert mimoun_returned_rels["children"] == mimoun_jr + assert mimoun_returned_rels["children"]._relations["preferred_course"] == dessin + assert mimoun_returned_rels["parents"] == mireille + assert mimoun_returned_rels["latest_course_relationship"].level == "1.2" @mark_sync_test From 64ccf2072c4bd0bd2fd6608efdfaf1d2b31fe14d Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 4 Nov 2024 13:16:33 +0100 Subject: [PATCH 34/42] Fix docstring --- neomodel/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neomodel/properties.py b/neomodel/properties.py index 8c848ea8..0cd716bd 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -418,7 +418,7 @@ class DateTimeFormatProperty(Property): """ Store a datetime by custom format :param default_now: If ``True``, the 
creation time (Local) will be used as default. - Defaults to ``False``. + Defaults to ``False``. :param format: Date format string, default is %Y-%m-%d :type default_now: :class:`bool` From d6169ed65d1da02bee74e33dee2a86f1a2bb5740 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 4 Nov 2024 13:21:40 +0100 Subject: [PATCH 35/42] Fix test --- test/async_/test_match_api.py | 14 +++++++------- test/sync_/test_match_api.py | 16 +++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 56d5f37f..c73d5b19 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -186,7 +186,7 @@ async def test_get(): @mark_async_test async def test_simple_traverse_with_filter(): nescafe = await Coffee(name="Nescafe2", price=99).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=2).save() + tesco = await Supplier(name="Tesco", delivery_cost=2).save() await nescafe.suppliers.connect(tesco) qb = AsyncQueryBuilder( @@ -200,7 +200,7 @@ async def test_simple_traverse_with_filter(): assert qb._ast.match assert qb._ast.return_clause.startswith("suppliers") assert len(results) == 1 - assert results[0].name == "Sainsburys" + assert results[0].name == "Tesco" @mark_async_test @@ -711,7 +711,7 @@ async def test_fetch_relations(): await nescafe.species.connect(arabica) result = ( - await Supplier.nodes.filter(name="Sainsburys") + await Supplier.nodes.filter(name="Tesco") .fetch_relations("coffees__species") .all() ) @@ -731,7 +731,7 @@ async def test_fetch_relations(): if AsyncUtil.is_async_code: count = ( - await Supplier.nodes.filter(name="Sainsburys") + await Supplier.nodes.filter(name="Tesco") .fetch_relations("coffees__species") .get_len() ) @@ -739,19 +739,19 @@ async def test_fetch_relations(): assert ( await Supplier.nodes.fetch_relations("coffees__species") - .filter(name="Sainsburys") + .filter(name="Tesco") .check_contains(tesco) ) else: count = len( - 
Supplier.nodes.filter(name="Sainsburys") + Supplier.nodes.filter(name="Tesco") .fetch_relations("coffees__species") .all() ) assert count == 1 assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( - name="Sainsburys" + name="Tesco" ) diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 60b48554..f4bfe7dc 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -184,7 +184,7 @@ def test_get(): @mark_sync_test def test_simple_traverse_with_filter(): nescafe = Coffee(name="Nescafe2", price=99).save() - tesco = Supplier(name="Sainsburys", delivery_cost=2).save() + tesco = Supplier(name="Tesco", delivery_cost=2).save() nescafe.suppliers.connect(tesco) qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) @@ -196,7 +196,7 @@ def test_simple_traverse_with_filter(): assert qb._ast.match assert qb._ast.return_clause.startswith("suppliers") assert len(results) == 1 - assert results[0].name == "Sainsburys" + assert results[0].name == "Tesco" @mark_sync_test @@ -699,9 +699,7 @@ def test_fetch_relations(): nescafe.species.connect(arabica) result = ( - Supplier.nodes.filter(name="Sainsburys") - .fetch_relations("coffees__species") - .all() + Supplier.nodes.filter(name="Tesco").fetch_relations("coffees__species").all() ) assert len(result[0]) == 5 assert arabica in result[0] @@ -719,7 +717,7 @@ def test_fetch_relations(): if Util.is_async_code: count = ( - Supplier.nodes.filter(name="Sainsburys") + Supplier.nodes.filter(name="Tesco") .fetch_relations("coffees__species") .__len__() ) @@ -727,19 +725,19 @@ def test_fetch_relations(): assert ( Supplier.nodes.fetch_relations("coffees__species") - .filter(name="Sainsburys") + .filter(name="Tesco") .__contains__(tesco) ) else: count = len( - Supplier.nodes.filter(name="Sainsburys") + Supplier.nodes.filter(name="Tesco") .fetch_relations("coffees__species") .all() ) assert count == 1 assert tesco in 
Supplier.nodes.fetch_relations("coffees__species").filter( - name="Sainsburys" + name="Tesco" ) From 02208a95604f8a55576bb7246a96542a853c10a3 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 4 Nov 2024 17:00:20 +0100 Subject: [PATCH 36/42] Added automatic DB cleanup after each test --- pyproject.toml | 1 + test/_async_compat/__init__.py | 4 +++ test/_async_compat/mark_decorator.py | 10 ++++-- test/async_/conftest.py | 25 ++++++------- test/async_/test_issue283.py | 25 ------------- test/async_/test_issue600.py | 10 ------ test/async_/test_match_api.py | 53 ---------------------------- test/async_/test_paths.py | 9 ----- test/async_/test_properties.py | 17 --------- test/async_/test_transactions.py | 14 -------- test/sync_/conftest.py | 15 +++++--- test/sync_/test_issue283.py | 25 ------------- test/sync_/test_issue600.py | 10 ------ test/sync_/test_match_api.py | 53 ---------------------------- test/sync_/test_paths.py | 9 ----- test/sync_/test_properties.py | 17 --------- test/sync_/test_transactions.py | 14 -------- 17 files changed, 34 insertions(+), 277 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3ad12f86..ce83e943 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ where = ["./"] [tool.pytest.ini_options] addopts = "--resetdb" testpaths = "test" +asyncio_default_fixture_loop_scope = "session" [tool.isort] profile = 'black' diff --git a/test/_async_compat/__init__.py b/test/_async_compat/__init__.py index 342678c3..5bdc28e3 100644 --- a/test/_async_compat/__init__.py +++ b/test/_async_compat/__init__.py @@ -1,8 +1,10 @@ from .mark_decorator import ( AsyncTestDecorators, TestDecorators, + mark_async_function_auto_fixture, mark_async_session_auto_fixture, mark_async_test, + mark_sync_function_auto_fixture, mark_sync_session_auto_fixture, mark_sync_test, ) @@ -13,5 +15,7 @@ "mark_sync_test", "TestDecorators", "mark_async_session_auto_fixture", + "mark_async_function_auto_fixture", "mark_sync_session_auto_fixture", + 
"mark_sync_function_auto_fixture", ] diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py index a8c5eead..5d6050d8 100644 --- a/test/_async_compat/mark_decorator.py +++ b/test/_async_compat/mark_decorator.py @@ -1,9 +1,15 @@ import pytest import pytest_asyncio -mark_async_test = pytest.mark.asyncio -mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True) +mark_async_test = pytest.mark.asyncio(loop_scope="session") +mark_async_session_auto_fixture = pytest_asyncio.fixture( + loop_scope="session", scope="session", autouse=True +) +mark_async_function_auto_fixture = pytest_asyncio.fixture( + loop_scope="session", autouse=True +) mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True) +mark_sync_function_auto_fixture = pytest.fixture(autouse=True) def mark_sync_test(f): diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 493ff12c..8cbf952b 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -1,15 +1,15 @@ -import asyncio import os import warnings -from test._async_compat import mark_async_session_auto_fixture - -import pytest +from test._async_compat import ( + mark_async_function_auto_fixture, + mark_async_session_auto_fixture, +) from neomodel import adb, config @mark_async_session_auto_fixture -async def setup_neo4j_session(request, event_loop): +async def setup_neo4j_session(request): """ Provides initial connection to the database and sets up the rest of the test suite @@ -44,17 +44,12 @@ async def setup_neo4j_session(request, event_loop): await adb.cypher_query("GRANT ROLE publisher TO troygreene") await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") - -@mark_async_session_auto_fixture -async def cleanup(event_loop): yield + await adb.close_connection() -@pytest.fixture(scope="session") -def event_loop(): - """Overrides pytest default function scoped event loop""" - policy = asyncio.get_event_loop_policy() - loop = 
policy.new_event_loop() - yield loop - loop.close() +@mark_async_function_auto_fixture +async def setUp(): + await adb.cypher_query("MATCH (n) DETACH DELETE n") + yield diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index 8106a796..6682b765 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -122,10 +122,6 @@ async def test_automatic_result_resolution(): # TechnicalPerson (!NOT basePerson!) assert type((await A.friends_with)[0]) is TechnicalPerson - await A.delete() - await B.delete() - await C.delete() - @mark_async_test async def test_recursive_automatic_result_resolution(): @@ -176,11 +172,6 @@ async def test_recursive_automatic_result_resolution(): # Assert that primitive data types remain primitive data types assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - await A.delete() - await B.delete() - await C.delete() - await D.delete() - @mark_async_test async def test_validation_with_inheritance_from_db(): @@ -240,12 +231,6 @@ async def test_validation_with_inheritance_from_db(): ) assert type((await D.friends_with)[0]) is PilotPerson - await A.delete() - await B.delete() - await C.delete() - await D.delete() - await E.delete() - @mark_async_test async def test_validation_enforcement_to_db(): @@ -295,13 +280,6 @@ async def test_validation_enforcement_to_db(): with pytest.raises(ValueError): await A.friends_with.connect(F) - await A.delete() - await B.delete() - await C.delete() - await D.delete() - await E.delete() - await F.delete() - @mark_async_test async def test_failed_result_resolution(): @@ -344,9 +322,6 @@ class RandomPerson(BasePerson): for some_friend in friends: print(some_friend.name) - await A.delete() - await B.delete() - @mark_async_test async def test_node_label_mismatch(): diff --git a/test/async_/test_issue600.py b/test/async_/test_issue600.py index 5f66f39e..3cf4e870 100644 --- a/test/async_/test_issue600.py +++ b/test/async_/test_issue600.py @@ -63,11 +63,6 @@ async def 
test_relationship_definer_second_sibling(): await B.rel_2.connect(C) await C.rel_3.connect(A) - # Clean up - await A.delete() - await B.delete() - await C.delete() - @mark_async_test async def test_relationship_definer_parent_last(): @@ -80,8 +75,3 @@ async def test_relationship_definer_parent_last(): await A.rel_1.connect(B) await B.rel_2.connect(C) await C.rel_3.connect(A) - - # Clean up - await A.delete() - await B.delete() - await C.delete() diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index c73d5b19..39e96957 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -253,9 +253,6 @@ async def test_len_and_iter_and_bool(): @mark_async_test async def test_slice(): - for c in await Coffee.nodes: - await c.delete() - await Coffee(name="Icelands finest").save() await Coffee(name="Britains finest").save() await Coffee(name="Japans finest").save() @@ -324,9 +321,6 @@ async def test_contains(): @mark_async_test async def test_order_by(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - c1 = await Coffee(name="Icelands finest", price=5).save() c2 = await Coffee(name="Britains finest", price=10).save() c3 = await Coffee(name="Japans finest", price=35).save() @@ -369,9 +363,6 @@ async def test_order_by(): @mark_async_test async def test_order_by_rawcypher(): - # Clean DB before we start anything... 
- await adb.cypher_query("MATCH (n) DETACH DELETE n") - d1 = await SoftwareDependency(name="Package1", version="1.0.0").save() d2 = await SoftwareDependency(name="Package2", version="1.4.0").save() d3 = await SoftwareDependency(name="Package3", version="2.5.5").save() @@ -392,9 +383,6 @@ async def test_order_by_rawcypher(): @mark_async_test async def test_extra_filters(): - for c in await Coffee.nodes: - await c.delete() - c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() c2 = await Coffee(name="Britains finest", price=10, id_=2).save() c3 = await Coffee(name="Japans finest", price=35, id_=3).save() @@ -466,10 +454,6 @@ async def test_empty_filters(): ``get_queryset`` function in ``GenericAPIView`` should returns ``NodeSet`` object. """ - - for c in await Coffee.nodes: - await c.delete() - c1 = await Coffee(name="Super", price=5, id_=1).save() c2 = await Coffee(name="Puper", price=10, id_=2).save() @@ -493,10 +477,6 @@ async def test_empty_filters(): @mark_async_test async def test_q_filters(): - # Test where no children and self.connector != conn ? - for c in await Coffee.nodes: - await c.delete() - c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() c2 = await Coffee(name="Britains finest", price=10, id_=2).save() c3 = await Coffee(name="Japans finest", price=35, id_=3).save() @@ -606,9 +586,6 @@ async def test_traversal_filter_left_hand_statement(): @mark_async_test async def test_filter_with_traversal(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=11).save() @@ -636,9 +613,6 @@ async def test_filter_with_traversal(): @mark_async_test async def test_relation_prop_filtering(): - # Clean DB before we start anything... 
- await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() nescafe = await Coffee(name="Nescafe", price=99).save() supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() @@ -668,9 +642,6 @@ async def test_relation_prop_filtering(): @mark_async_test async def test_relation_prop_ordering(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() nescafe = await Coffee(name="Nescafe", price=99).save() supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() @@ -697,9 +668,6 @@ async def test_relation_prop_ordering(): @mark_async_test async def test_fetch_relations(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=99).save() @@ -757,9 +725,6 @@ async def test_fetch_relations(): @mark_async_test async def test_traverse_and_order_by(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=99).save() @@ -781,9 +746,6 @@ async def test_traverse_and_order_by(): @mark_async_test async def test_annotate_and_collect(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe 1002", price=99).save() @@ -834,9 +796,6 @@ async def test_annotate_and_collect(): @mark_async_test async def test_resolve_subgraph(): - # Clean DB before we start anything... 
- await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=99).save() @@ -884,9 +843,6 @@ async def test_resolve_subgraph(): @mark_async_test async def test_resolve_subgraph_optional(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() @@ -911,9 +867,6 @@ async def test_resolve_subgraph_optional(): @mark_async_test async def test_subquery(): - # Clean DB before we start anything... - await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() nescafe = await Coffee(name="Nescafe", price=99).save() supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() @@ -950,9 +903,6 @@ async def test_subquery(): @mark_async_test async def test_intermediate_transform(): - # Clean DB before we start anything... 
- await adb.cypher_query("MATCH (n) DETACH DELETE n") - arabica = await Species(name="Arabica").save() nescafe = await Coffee(name="Nescafe", price=99).save() supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() @@ -1133,9 +1083,6 @@ async def test_in_filter_with_array_property(): async def test_async_iterator(): n = 10 if AsyncUtil.is_async_code: - for c in await Coffee.nodes: - await c.delete() - for i in range(n): await Coffee(name=f"xxx_{i}", price=i).save() diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py index 59a5e385..f0599e01 100644 --- a/test/async_/test_paths.py +++ b/test/async_/test_paths.py @@ -85,12 +85,3 @@ async def test_path_instantiation(): assert type(path_rels[0]) is PersonLivesInCity assert type(path_rels[1]) is AsyncStructuredRel - - await c1.delete() - await c2.delete() - await ct1.delete() - await ct2.delete() - await p1.delete() - await p2.delete() - await p3.delete() - await p4.delete() diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 4f3eab2d..1e8f0c44 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -418,10 +418,6 @@ async def test_independent_property_name(): rel = await x.knows.relationship(x) assert rel.known_for == r.known_for - # -- cleanup -- - - await x.delete() - @mark_async_test async def test_independent_property_name_for_semi_structured(): @@ -455,8 +451,6 @@ class TestDBNamePropertySemiStructuredNode(AsyncSemiStructuredNode): # assert not hasattr(from_get, "title") assert from_get.extra == "data" - await semi.delete() - @mark_async_test async def test_independent_property_name_get_or_create(): @@ -475,9 +469,6 @@ class TestNode(AsyncStructuredNode): assert node_properties["name"] == "jim" assert "name_" not in node_properties - # delete node afterwards - await x.delete() - @mark.parametrize("normalized_class", (NormalizedProperty,)) def test_normalized_property(normalized_class): @@ -648,9 +639,6 @@ class 
ConstrainedTestNode(AsyncStructuredNode): node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["unique_required_property"] == "unique and required" - # delete node afterwards - await x.delete() - @mark_async_test async def test_unique_index_prop_enforced(): @@ -675,11 +663,6 @@ class UniqueNullableNameNode(AsyncStructuredNode): results, _ = await adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") assert len(results) == 3 - # Delete nodes afterwards - await x.delete() - await y.delete() - await z.delete() - def test_alias_property(): class AliasedClass(AsyncStructuredNode): diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index 59d523c5..de7a13e5 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -14,9 +14,6 @@ class APerson(AsyncStructuredNode): @mark_async_test async def test_rollback_and_commit_transaction(): - for p in await APerson.nodes: - await p.delete() - await APerson(name="Roger").save() await adb.begin() @@ -41,8 +38,6 @@ async def in_a_tx(*names): @mark_async_test async def test_transaction_decorator(): await adb.install_labels(APerson) - for p in await APerson.nodes: - await p.delete() # should work await in_a_tx("Roger") @@ -68,9 +63,6 @@ async def test_transaction_as_a_context(): @mark_async_test async def test_query_inside_transaction(): - for p in await APerson.nodes: - await p.delete() - async with adb.transaction: await APerson(name="Alice").save() await APerson(name="Bob").save() @@ -119,9 +111,6 @@ async def in_a_tx_with_bookmark(*names): @mark_async_test async def test_bookmark_transaction_decorator(): - for p in await APerson.nodes: - await p.delete() - # should work result, bookmarks = await in_a_tx_with_bookmark("Ruth", bookmarks=None) assert result is None @@ -181,9 +170,6 @@ async def test_bookmark_passed_in_to_context(spy_on_db_begin): @mark_async_test async def test_query_inside_bookmark_transaction(): - for p in await 
APerson.nodes: - await p.delete() - async with adb.transaction as transaction: await APerson(name="Alice").save() await APerson(name="Bob").save() diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py index d2cd787e..cbe38140 100644 --- a/test/sync_/conftest.py +++ b/test/sync_/conftest.py @@ -1,6 +1,9 @@ import os import warnings -from test._async_compat import mark_sync_session_auto_fixture +from test._async_compat import ( + mark_async_function_auto_fixture, + mark_sync_session_auto_fixture, +) from neomodel import config, db @@ -41,8 +44,12 @@ def setup_neo4j_session(request): db.cypher_query("GRANT ROLE publisher TO troygreene") db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") - -@mark_sync_session_auto_fixture -def cleanup(): yield + db.close_connection() + + +@mark_async_function_auto_fixture +def setUp(): + db.cypher_query("MATCH (n) DETACH DELETE n") + yield diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index 611431ce..842e21e5 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -117,10 +117,6 @@ def test_automatic_result_resolution(): # TechnicalPerson (!NOT basePerson!) 
assert type((A.friends_with)[0]) is TechnicalPerson - A.delete() - B.delete() - C.delete() - @mark_sync_test def test_recursive_automatic_result_resolution(): @@ -159,11 +155,6 @@ def test_recursive_automatic_result_resolution(): # Assert that primitive data types remain primitive data types assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - A.delete() - B.delete() - C.delete() - D.delete() - @mark_sync_test def test_validation_with_inheritance_from_db(): @@ -217,12 +208,6 @@ def test_validation_with_inheritance_from_db(): ) assert type((D.friends_with)[0]) is PilotPerson - A.delete() - B.delete() - C.delete() - D.delete() - E.delete() - @mark_sync_test def test_validation_enforcement_to_db(): @@ -266,13 +251,6 @@ def test_validation_enforcement_to_db(): with pytest.raises(ValueError): A.friends_with.connect(F) - A.delete() - B.delete() - C.delete() - D.delete() - E.delete() - F.delete() - @mark_sync_test def test_failed_result_resolution(): @@ -311,9 +289,6 @@ class RandomPerson(BasePerson): for some_friend in friends: print(some_friend.name) - A.delete() - B.delete() - @mark_sync_test def test_node_label_mismatch(): diff --git a/test/sync_/test_issue600.py b/test/sync_/test_issue600.py index f6b5a10b..181a156d 100644 --- a/test/sync_/test_issue600.py +++ b/test/sync_/test_issue600.py @@ -63,11 +63,6 @@ def test_relationship_definer_second_sibling(): B.rel_2.connect(C) C.rel_3.connect(A) - # Clean up - A.delete() - B.delete() - C.delete() - @mark_sync_test def test_relationship_definer_parent_last(): @@ -80,8 +75,3 @@ def test_relationship_definer_parent_last(): A.rel_1.connect(B) B.rel_2.connect(C) C.rel_3.connect(A) - - # Clean up - A.delete() - B.delete() - C.delete() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index f4bfe7dc..78909860 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -249,9 +249,6 @@ def test_len_and_iter_and_bool(): @mark_sync_test def test_slice(): - for c in 
Coffee.nodes: - c.delete() - Coffee(name="Icelands finest").save() Coffee(name="Britains finest").save() Coffee(name="Japans finest").save() @@ -320,9 +317,6 @@ def test_contains(): @mark_sync_test def test_order_by(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - c1 = Coffee(name="Icelands finest", price=5).save() c2 = Coffee(name="Britains finest", price=10).save() c3 = Coffee(name="Japans finest", price=35).save() @@ -365,9 +359,6 @@ def test_order_by(): @mark_sync_test def test_order_by_rawcypher(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - d1 = SoftwareDependency(name="Package1", version="1.0.0").save() d2 = SoftwareDependency(name="Package2", version="1.4.0").save() d3 = SoftwareDependency(name="Package3", version="2.5.5").save() @@ -388,9 +379,6 @@ def test_order_by_rawcypher(): @mark_sync_test def test_extra_filters(): - for c in Coffee.nodes: - c.delete() - c1 = Coffee(name="Icelands finest", price=5, id_=1).save() c2 = Coffee(name="Britains finest", price=10, id_=2).save() c3 = Coffee(name="Japans finest", price=35, id_=3).save() @@ -462,10 +450,6 @@ def test_empty_filters(): ``get_queryset`` function in ``GenericAPIView`` should returns ``NodeSet`` object. """ - - for c in Coffee.nodes: - c.delete() - c1 = Coffee(name="Super", price=5, id_=1).save() c2 = Coffee(name="Puper", price=10, id_=2).save() @@ -489,10 +473,6 @@ def test_empty_filters(): @mark_sync_test def test_q_filters(): - # Test where no children and self.connector != conn ? - for c in Coffee.nodes: - c.delete() - c1 = Coffee(name="Icelands finest", price=5, id_=1).save() c2 = Coffee(name="Britains finest", price=10, id_=2).save() c3 = Coffee(name="Japans finest", price=35, id_=3).save() @@ -600,9 +580,6 @@ def test_traversal_filter_left_hand_statement(): @mark_sync_test def test_filter_with_traversal(): - # Clean DB before we start anything... 
- db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=11).save() @@ -628,9 +605,6 @@ def test_filter_with_traversal(): @mark_sync_test def test_relation_prop_filtering(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() nescafe = Coffee(name="Nescafe", price=99).save() supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() @@ -660,9 +634,6 @@ def test_relation_prop_filtering(): @mark_sync_test def test_relation_prop_ordering(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() nescafe = Coffee(name="Nescafe", price=99).save() supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() @@ -685,9 +656,6 @@ def test_relation_prop_ordering(): @mark_sync_test def test_fetch_relations(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=99).save() @@ -743,9 +711,6 @@ def test_fetch_relations(): @mark_sync_test def test_traverse_and_order_by(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=99).save() @@ -765,9 +730,6 @@ def test_traverse_and_order_by(): @mark_sync_test def test_annotate_and_collect(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe 1002", price=99).save() @@ -818,9 +780,6 @@ def test_annotate_and_collect(): @mark_sync_test def test_resolve_subgraph(): - # Clean DB before we start anything... 
- db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=99).save() @@ -868,9 +827,6 @@ def test_resolve_subgraph(): @mark_sync_test def test_resolve_subgraph_optional(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() @@ -895,9 +851,6 @@ def test_resolve_subgraph_optional(): @mark_sync_test def test_subquery(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() nescafe = Coffee(name="Nescafe", price=99).save() supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() @@ -934,9 +887,6 @@ def test_subquery(): @mark_sync_test def test_intermediate_transform(): - # Clean DB before we start anything... - db.cypher_query("MATCH (n) DETACH DELETE n") - arabica = Species(name="Arabica").save() nescafe = Coffee(name="Nescafe", price=99).save() supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() @@ -1117,9 +1067,6 @@ def test_in_filter_with_array_property(): def test_async_iterator(): n = 10 if Util.is_async_code: - for c in Coffee.nodes: - c.delete() - for i in range(n): Coffee(name=f"xxx_{i}", price=i).save() diff --git a/test/sync_/test_paths.py b/test/sync_/test_paths.py index 8e0ccf90..1a6429bf 100644 --- a/test/sync_/test_paths.py +++ b/test/sync_/test_paths.py @@ -85,12 +85,3 @@ def test_path_instantiation(): assert type(path_rels[0]) is PersonLivesInCity assert type(path_rels[1]) is StructuredRel - - c1.delete() - c2.delete() - ct1.delete() - ct2.delete() - p1.delete() - p2.delete() - p3.delete() - p4.delete() diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 1afe52a2..53ae0002 100644 --- a/test/sync_/test_properties.py +++ 
b/test/sync_/test_properties.py @@ -408,10 +408,6 @@ def test_independent_property_name(): rel = x.knows.relationship(x) assert rel.known_for == r.known_for - # -- cleanup -- - - x.delete() - @mark_sync_test def test_independent_property_name_for_semi_structured(): @@ -445,8 +441,6 @@ class TestDBNamePropertySemiStructuredNode(SemiStructuredNode): # assert not hasattr(from_get, "title") assert from_get.extra == "data" - semi.delete() - @mark_sync_test def test_independent_property_name_get_or_create(): @@ -465,9 +459,6 @@ class TestNode(StructuredNode): assert node_properties["name"] == "jim" assert "name_" not in node_properties - # delete node afterwards - x.delete() - @mark.parametrize("normalized_class", (NormalizedProperty,)) def test_normalized_property(normalized_class): @@ -638,9 +629,6 @@ class ConstrainedTestNode(StructuredNode): node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["unique_required_property"] == "unique and required" - # delete node afterwards - x.delete() - @mark_sync_test def test_unique_index_prop_enforced(): @@ -665,11 +653,6 @@ class UniqueNullableNameNode(StructuredNode): results, _ = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") assert len(results) == 3 - # Delete nodes afterwards - x.delete() - y.delete() - z.delete() - def test_alias_property(): class AliasedClass(StructuredNode): diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index 834b538e..71ce479f 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -14,9 +14,6 @@ class APerson(StructuredNode): @mark_sync_test def test_rollback_and_commit_transaction(): - for p in APerson.nodes: - p.delete() - APerson(name="Roger").save() db.begin() @@ -41,8 +38,6 @@ def in_a_tx(*names): @mark_sync_test def test_transaction_decorator(): db.install_labels(APerson) - for p in APerson.nodes: - p.delete() # should work in_a_tx("Roger") @@ -68,9 +63,6 @@ def 
test_transaction_as_a_context(): @mark_sync_test def test_query_inside_transaction(): - for p in APerson.nodes: - p.delete() - with db.transaction: APerson(name="Alice").save() APerson(name="Bob").save() @@ -119,9 +111,6 @@ def in_a_tx_with_bookmark(*names): @mark_sync_test def test_bookmark_transaction_decorator(): - for p in APerson.nodes: - p.delete() - # should work result, bookmarks = in_a_tx_with_bookmark("Ruth", bookmarks=None) assert result is None @@ -181,9 +170,6 @@ def test_bookmark_passed_in_to_context(spy_on_db_begin): @mark_sync_test def test_query_inside_bookmark_transaction(): - for p in APerson.nodes: - p.delete() - with db.transaction as transaction: APerson(name="Alice").save() APerson(name="Bob").save() From 2b014badc1a098087dc600d17350569a8f4190fa Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 4 Nov 2024 17:14:14 +0100 Subject: [PATCH 37/42] Fixed test --- test/async_/test_issue283.py | 6 +++++- test/sync_/test_issue283.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index 6682b765..ddbd6808 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -484,6 +484,10 @@ async def test_resolve_inexistent_relationship(): Attempting to resolve an inexistent relationship should raise an exception :return: """ + A = await TechnicalPerson(name="Michael Knight", expertise="Cars").save() + B = await TechnicalPerson(name="Luke Duke", expertise="Lasers").save() + + await A.friends_with.connect(B) # Forget about the FRIENDS_WITH Relationship. 
del adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] @@ -493,7 +497,7 @@ async def test_resolve_inexistent_relationship(): match=r"Relationship of type .* does not resolve to any of the known objects.*", ): query_data = await adb.cypher_query( - "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " + "MATCH (:TechnicalPerson)-[r:FRIENDS_WITH]->(:TechnicalPerson) " "RETURN DISTINCT r", resolve_objects=True, ) diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index 842e21e5..fab4f0d7 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -445,6 +445,10 @@ def test_resolve_inexistent_relationship(): Attempting to resolve an inexistent relationship should raise an exception :return: """ + A = TechnicalPerson(name="Michael Knight", expertise="Cars").save() + B = TechnicalPerson(name="Luke Duke", expertise="Lasers").save() + + A.friends_with.connect(B) # Forget about the FRIENDS_WITH Relationship. del db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] @@ -454,7 +458,7 @@ def test_resolve_inexistent_relationship(): match=r"Relationship of type .* does not resolve to any of the known objects.*", ): query_data = db.cypher_query( - "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " + "MATCH (:TechnicalPerson)-[r:FRIENDS_WITH]->(:TechnicalPerson) " "RETURN DISTINCT r", resolve_objects=True, ) From cad731d973e115f31203a39d1d1130d61c6ddf06 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 4 Nov 2024 17:39:51 +0100 Subject: [PATCH 38/42] Does it reduce complexity? 
--- neomodel/async_/match.py | 10 +++++----- neomodel/sync_/match.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 8f103c4f..6a2f9933 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1505,11 +1505,11 @@ async def resolve_subgraph(self) -> list: for name, node in row.items(): if node.__class__ is self.source and "_" not in name: root_node = node - else: - if isinstance(node, list) and isinstance(node[0], list): - other_nodes[name] = node[0] - else: - other_nodes[name] = node + continue + if isinstance(node, list) and isinstance(node[0], list): + other_nodes[name] = node[0] + continue + other_nodes[name] = node results.append( self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) ) diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 1bdfe660..966b2601 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1505,11 +1505,11 @@ def resolve_subgraph(self) -> list: for name, node in row.items(): if node.__class__ is self.source and "_" not in name: root_node = node - else: - if isinstance(node, list) and isinstance(node[0], list): - other_nodes[name] = node[0] - else: - other_nodes[name] = node + continue + if isinstance(node, list) and isinstance(node[0], list): + other_nodes[name] = node[0] + continue + other_nodes[name] = node results.append( self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) ) From ce74a54e5d817d61f6e097d1c9b8dc32b86c0785 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 5 Nov 2024 09:36:31 +0100 Subject: [PATCH 39/42] Fix tests --- test/async_/test_properties.py | 6 +++++- test/sync_/test_properties.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 6949372d..d90f7d4e 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -309,6 +309,10 @@ def 
test_json(): assert prop.deflate(value) == '{"test": [1, 2, 3]}' assert prop.inflate('{"test": [1, 2, 3]}') == value + value_with_unicode = {"test": [1, 2, 3, "©"]} + assert prop.deflate(value_with_unicode) == '{"test": [1, 2, 3, "\\u00a9"]}' + assert prop.inflate('{"test": [1, 2, 3, "\\u00a9"]}') == value_with_unicode + def test_json_unicode(): prop = JSONProperty(ensure_ascii=False) @@ -318,7 +322,7 @@ def test_json_unicode(): value = {"test": [1, 2, 3, "©"]} assert prop.deflate(value) == '{"test": [1, 2, 3, "©"]}' - assert prop.inflate('{"test": [1, 2, 3, ©]}') == value + assert prop.inflate('{"test": [1, 2, 3, "©"]}') == value def test_indexed(): diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index b61182fb..c00f3bdd 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -303,6 +303,10 @@ def test_json(): assert prop.deflate(value) == '{"test": [1, 2, 3]}' assert prop.inflate('{"test": [1, 2, 3]}') == value + value_with_unicode = {"test": [1, 2, 3, "©"]} + assert prop.deflate(value_with_unicode) == '{"test": [1, 2, 3, "\\u00a9"]}' + assert prop.inflate('{"test": [1, 2, 3, "\\u00a9"]}') == value_with_unicode + def test_json_unicode(): prop = JSONProperty(ensure_ascii=False) @@ -312,7 +316,7 @@ def test_json_unicode(): value = {"test": [1, 2, 3, "©"]} assert prop.deflate(value) == '{"test": [1, 2, 3, "©"]}' - assert prop.inflate('{"test": [1, 2, 3, ©]}') == value + assert prop.inflate('{"test": [1, 2, 3, "©"]}') == value def test_indexed(): From 0007c5770fb0d94afab3ccf07f678d4d162d89ca Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 5 Nov 2024 10:17:45 +0100 Subject: [PATCH 40/42] Add 3.13 and drop 3.8 in tests --- .github/workflows/integration-tests.yml | 2 +- README.md | 2 +- requirements-dev.txt | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 25b26cda..1a0450f7 100644 --- 
a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.12", "3.11", "3.10", "3.9", "3.8"] + python-version: ["3.13", "3.12", "3.11", "3.10", "3.9"] neo4j-version: ["community", "enterprise", "5.5-enterprise", "4.4-enterprise", "4.4-community"] steps: diff --git a/README.md b/README.md index 633cb2f1..07626dff 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Ensure `dbms.security.auth_enabled=true` in your database configuration file. Setup a virtual environment, install neomodel for development and run the test suite: : - $ pip install -e '.[dev,pandas,numpy]' + $ pip install -r requirements-dev.txt $ pytest The tests in \"test_connection.py\" will fail locally if you don\'t diff --git a/requirements-dev.txt b/requirements-dev.txt index bf3fa116..446dd8c1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,6 +3,7 @@ unasync>=0.5.0 pytest>=7.1 +pytest-asyncio>=0.19.0 pytest-cov>=4.0 pre-commit black From d3959f2e957b55135ac0b6a73311ae4499aec90f Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 5 Nov 2024 10:45:15 +0100 Subject: [PATCH 41/42] Update Changelog and README --- Changelog | 11 +++++++++++ README.md | 10 +++++++++- pyproject.toml | 4 ++-- requirements.txt | 2 +- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/Changelog b/Changelog index b65b90a7..7dbbde03 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,14 @@ +Version 5.4.0 2024-11 +* Traversal option for filtering and ordering +* Insert raw Cypher for ordering +* Possibility to traverse relations, only returning the last element of the path +* Resolve the results of complex queries as a nested subgraph +* Possibility to transform variables, with aggregations methods : Collect() and Last() +* Intermediate transform, for example to order variables before collecting +* Subqueries (Cypher CALL{} clause) +* Allow JSONProperty to actually use non-ascii 
elements +* Bumped neo4j (driver) to 5.26.0 + Version 5.3.3 2024-09 * Fixes vector index doc and test diff --git a/README.md b/README.md index 07626dff..f72d5a41 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ GitHub repo found at . **For neomodel releases 5.x :** -- Python 3.7+ +- Python 3.8+ - Neo4j 5.x, 4.4 (LTS) **For neomodel releases 4.x :** @@ -37,6 +37,14 @@ GitHub repo found at . Available on [readthedocs](http://neomodel.readthedocs.org). +# New in 5.4.0 + +This version adds many new features, expanding neomodel's querying capabilities. These features were kindly contributed back by the [OpenStudyBuilder team](https://openstudybuilder.com/). A VERY special thanks to @tonioo for the integration work. + +There are too many new capabilities to list here, so I advise you to start by looking at the full summary example in the [Getting Started guide](https://neomodel.readthedocs.io/en/latest/getting_started.html#full-example). It will then point you to the various relevant sections. + +We also validated support for [Python 3.13](https://docs.python.org/3/whatsnew/3.13.html). + # New in 5.3.0 neomodel now supports asynchronous programming, thanks to the [Neo4j driver async API](https://neo4j.com/docs/api/python-driver/current/async_api.html). The [documentation](http://neomodel.readthedocs.org) has been updated accordingly, with an updated getting started section, and some specific documentation for the async API.
diff --git a/pyproject.toml b/pyproject.toml index ce83e943..d72c546b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,9 @@ classifiers = [ "Topic :: Database", ] dependencies = [ - "neo4j~=5.19.0", + "neo4j~=5.26.0", ] -requires-python = ">=3.7" +requires-python = ">=3.8" dynamic = ["version"] [project.urls] diff --git a/requirements.txt b/requirements.txt index ffbfe285..e7a3f522 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -neo4j~=5.19.0 +neo4j~=5.26.0 From d7747e7e24d39fa8f281c731faa1529f943d86a7 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 5 Nov 2024 10:47:41 +0100 Subject: [PATCH 42/42] Update changelog --- Changelog | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Changelog b/Changelog index 7dbbde03..079cab72 100644 --- a/Changelog +++ b/Changelog @@ -6,8 +6,9 @@ Version 5.4.0 2024-11 * Possibility to transform variables, with aggregations methods : Collect() and Last() * Intermediate transform, for example to order variables before collecting * Subqueries (Cypher CALL{} clause) -* Allow JSONProperty to actually use non-ascii elements +* Allow JSONProperty to actually use non-ascii elements. Thanks to @danikirish * Bumped neo4j (driver) to 5.26.0 +* Special huge thanks to @tonioo for this release Version 5.3.3 2024-09 * Fixes vector index doc and test