From 8c7d4a9bbdc1eb64a856d8313b33d93da6d92fb6 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 16:19:33 +0200 Subject: [PATCH 1/4] Added method to resolve a subgraph from a fetch_relations() call. --- docker-scripts/docker-neo4j.sh | 3 +- neomodel/async_/match.py | 87 +++++++++++++++++++++++++++++++-- neomodel/sync_/match.py | 89 ++++++++++++++++++++++++++++++++-- test/async_/test_match_api.py | 34 +++++++++++++ test/sync_/test_match_api.py | 34 +++++++++++++ 5 files changed, 240 insertions(+), 7 deletions(-) diff --git a/docker-scripts/docker-neo4j.sh b/docker-scripts/docker-neo4j.sh index 99aabfff..6b146c95 100644 --- a/docker-scripts/docker-neo4j.sh +++ b/docker-scripts/docker-neo4j.sh @@ -5,4 +5,5 @@ docker run \ --env NEO4J_AUTH=neo4j/foobarbaz \ --env NEO4J_ACCEPT_LICENSE_AGREEMENT=yes \ --env NEO4JLABS_PLUGINS='["apoc"]' \ - neo4j:$1 \ No newline at end of file + --rm \ + neo4j:$1 diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index c203dcc2..2796f504 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -5,6 +5,7 @@ from typing import Any, Optional from neomodel.async_.core import AsyncStructuredNode, adb +from neomodel.async_.relationship import AsyncStructuredRel from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty, ArrayProperty @@ -414,16 +415,18 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.subgraph: dict = {} class AsyncQueryBuilder: - def __init__(self, node_set): + def __init__(self, node_set, with_subgraph: bool = False): self.node_set = node_set self._ast = QueryAST() self._query_params = {} self._place_holder_registry = {} self._ident_count = 0 self._node_counters = defaultdict(int) + self._with_subgraph: bool = with_subgraph async def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -516,6 +519,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: stmt: str = "" source_class_iterator = source_class parts = path.split("__") + if self._with_subgraph: + subgraph = self._ast.subgraph for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source @@ -549,6 +554,13 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: lhs_ident = stmt rel_ident = self.create_ident() + if self._with_subgraph and part not in self._ast.subgraph: + subgraph[part] = { + "target": relationship.definition["node_class"], + "children": {}, + "variable_name": rhs_name, + "rel_variable_name": rel_ident, + } if relation["include_in_return"]: self._additional_return(rel_ident) stmt = _rel_helper( @@ -559,6 +571,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] + if self._with_subgraph: + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -778,7 +792,7 @@ async def _contains(self, node_element_id): self._query_params[place_holder] = node_element_id return await self._count() >= 1 - async def _execute(self, lazy=False): + async def _execute(self, lazy: bool = False, dict_output: bool = False): if lazy: # inject id() into return or return_set if self._ast.return_clause: @@ -791,9 +805,13 @@ async def _execute(self, lazy=False): for item in self._ast.additional_return ] query = self.build_query() - results, _ = await adb.cypher_query( + results, prop_names = await adb.cypher_query( query, self._query_params, resolve_objects=True ) + if dict_output: + for item in results: + yield dict(zip(prop_names, item)) + return # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first @@ -1146,6 +1164,69 @@ def register_extra_var(vardef, varname: str = None): return self + def _to_subgraph(self, root_node, other_nodes, subgraph): + """Recursive method to build root_node's relation graph from subgraph.""" + root_node._relations = {} + for name, relation_def in subgraph.items(): + for var_name, node in other_nodes.items(): + if ( + var_name + not in [ + relation_def["variable_name"], + relation_def["rel_variable_name"], + ] + or node is None + ): + continue + if isinstance(node, list): + if len(node) > 0 and isinstance(node[0], AsyncStructuredRel): + name += "_relationship" + root_node._relations[name] = [] + for item in node: + root_node._relations[name].append( + self._to_subgraph( + item, other_nodes, relation_def["children"] + ) + ) + else: + if isinstance(node, AsyncStructuredRel): + name += "_relationship" + root_node._relations[name] = self._to_subgraph( + node, other_nodes, relation_def["children"] + ) + + return root_node + + async def resolve_subgraph(self) -> list: + """ + Convert every result contained in this node set to a subgraph. + + By default, we receive results from neomodel as a list of + nodes without the hierarchy. This method tries to rebuild this + hierarchy without overriding anything in the node, that's why + we use a dedicated property to store node's relations. + + """ + results: list = [] + qbuilder = self.query_cls(self, with_subgraph=True) + await qbuilder.build_ast() + all_nodes = qbuilder._execute(dict_output=True) + other_nodes = {} + root_node = None + async for row in all_nodes: + for name, node in row.items(): + if node.__class__ is self.source and "_" not in name: + root_node = node + else: + if isinstance(node, list) and isinstance(node[0], list): + other_nodes[name] = node[0] + else: + other_nodes[name] = node + results.append( + self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) + ) + return results + class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 8f163c53..8221c979 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -8,6 +8,7 @@ from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty, ArrayProperty from neomodel.sync_.core import StructuredNode, db +from neomodel.sync_.relationship import StructuredRel from neomodel.util import INCOMING, OUTGOING @@ -414,16 +415,18 @@ def __init__( self.lookup = lookup self.additional_return = additional_return if additional_return else [] self.is_count = is_count + self.subgraph: dict = {} class QueryBuilder: - def __init__(self, node_set): + def __init__(self, node_set, with_subgraph: bool = False): self.node_set = node_set self._ast = QueryAST() self._query_params = {} self._place_holder_registry = {} self._ident_count = 0 self._node_counters = defaultdict(int) + self._with_subgraph: bool = with_subgraph def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): @@ -516,6 +519,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: stmt: str = "" source_class_iterator = source_class parts = path.split("__") + if self._with_subgraph: + subgraph = self._ast.subgraph for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source @@ -549,6 +554,13 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: lhs_ident = stmt rel_ident = self.create_ident() + if self._with_subgraph and part not in self._ast.subgraph: + subgraph[part] = { + "target": relationship.definition["node_class"], + "children": {}, + "variable_name": rhs_name, + "rel_variable_name": rel_ident, + } if relation["include_in_return"]: self._additional_return(rel_ident) stmt = _rel_helper( @@ -559,6 +571,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: relation_type=relationship.definition["relation_type"], ) source_class_iterator = relationship.definition["node_class"] + if self._with_subgraph: + subgraph = subgraph[part]["children"] if relation.get("optional"): self._ast.optional_match.append(stmt) @@ -776,7 +790,7 @@ def _contains(self, node_element_id): self._query_params[place_holder] = node_element_id return self._count() >= 1 - def _execute(self, lazy=False): + def _execute(self, lazy: bool = False, dict_output: bool = False): if lazy: # inject id() into return or return_set if self._ast.return_clause: @@ -789,7 +803,13 @@ def _execute(self, lazy=False): for item in self._ast.additional_return ] query = self.build_query() - results, _ = db.cypher_query(query, self._query_params, resolve_objects=True) + results, prop_names = db.cypher_query( + query, self._query_params, resolve_objects=True + ) + if dict_output: + for item in results: + yield dict(zip(prop_names, item)) + return # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first @@ -1142,6 +1162,69 @@ def register_extra_var(vardef, varname: str = None): return self + def _to_subgraph(self, root_node, other_nodes, subgraph): + """Recursive method to build root_node's relation graph from subgraph.""" + root_node._relations = {} + for name, relation_def in subgraph.items(): + for var_name, node in other_nodes.items(): + if ( + var_name + not in [ + relation_def["variable_name"], + relation_def["rel_variable_name"], + ] + or node is None + ): + continue + if isinstance(node, list): + if len(node) > 0 and isinstance(node[0], StructuredRel): + name += "_relationship" + root_node._relations[name] = [] + for item in node: + root_node._relations[name].append( + self._to_subgraph( + item, other_nodes, relation_def["children"] + ) + ) + else: + if isinstance(node, StructuredRel): + name += "_relationship" + root_node._relations[name] = self._to_subgraph( + node, other_nodes, relation_def["children"] + ) + + return root_node + + def resolve_subgraph(self) -> list: + """ + Convert every result contained in this node set to a subgraph. + + By default, we receive results from neomodel as a list of + nodes without the hierarchy. This method tries to rebuild this + hierarchy without overriding anything in the node, that's why + we use a dedicated property to store node's relations. + + """ + results: list = [] + qbuilder = self.query_cls(self, with_subgraph=True) + qbuilder.build_ast() + all_nodes = qbuilder._execute(dict_output=True) + other_nodes = {} + root_node = None + for row in all_nodes: + for name, node in row.items(): + if node.__class__ is self.source and "_" not in name: + root_node = node + else: + if isinstance(node, list) and isinstance(node[0], list): + other_nodes[name] = node[0] + else: + other_nodes[name] = node + results.append( + self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph) + ) + return results + class Traversal(BaseSet): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 2175d604..ff8261ef 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -638,6 +638,40 @@ async def test_annotate_and_collect(): assert len(result[0][1][0]) == 2 # 2 species must be there +@mark_async_test +async def test_resolve_subgraph(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + + result = await Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() + assert len(result) == 2 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert robusta == coffees._relations["species"] + + assert hasattr(result[1], "_relations") + assert "coffees" in result[1]._relations + coffees = result[1]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert arabica == coffees._relations["species"] + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index d22e4100..5aecf3cf 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -626,6 +626,40 @@ def test_annotate_and_collect(): assert len(result[0][1][0]) == 2 # 2 species must be there +@mark_sync_test +def test_resolve_subgraph(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() + + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + + result = Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() + assert len(result) == 2 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert robusta == coffees._relations["species"] + + assert hasattr(result[1], "_relations") + assert "coffees" in result[1]._relations + coffees = result[1]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert arabica == coffees._relations["species"] + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create From fefabce3fbf795d942e7bbf08c08ac493c8d9184 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 17:26:47 +0200 Subject: [PATCH 2/4] Added constraints to limit use cases of resolve_subgraph() --- neomodel/async_/match.py | 61 +++++++++++++++++++++-------------- neomodel/sync_/match.py | 61 +++++++++++++++++++++-------------- test/async_/test_match_api.py | 19 +++++++++++ test/sync_/test_match_api.py | 19 +++++++++++ 4 files changed, 110 insertions(+), 50 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 2796f504..03033415 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1114,39 +1114,42 @@ def order_by(self, *props): return self - def _add_relations( - self, *relation_names, include_in_return=True, **aliased_relation_names + def _register_relation_to_fetch( + self, relation_def: Any, alias: str = None, include_in_return: bool = True ): - """Specify a set of relations to return.""" - relations = [] - - def register_relation_to_fetch(relation_def: Any, alias: str = None): - if isinstance(relation_def, Optional): - item = {"path": relation_def.relation, "optional": True} - else: - item = {"path": relation_def} - item["include_in_return"] = include_in_return - if alias: - item["alias"] = alias - relations.append(item) + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} + else: + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias + return item + def fetch_relations(self, *relation_names): + """Specify a set of relations to traverse and return.""" + relations = [] for relation_name in relation_names: - register_relation_to_fetch(relation_name) - for alias, relation_def in aliased_relation_names.items(): - register_relation_to_fetch(relation_def, alias) - + relations.append(self._register_relation_to_fetch(relation_name)) self.relations_to_fetch = relations return self - def fetch_relations(self, *relation_names, **aliased_relation_names): - """Specify a set of relations to traverse and return.""" - return self._add_relations(*relation_names, **aliased_relation_names) - def traverse_relations(self, *relation_names, **aliased_relation_names): """Specify a set of relations to traverse only.""" - return self._add_relations( - *relation_names, include_in_return=False, **aliased_relation_names - ) + relations = [] + for relation_name in relation_names: + relations.append( + self._register_relation_to_fetch(relation_name, include_in_return=False) + ) + for alias, relation_def in aliased_relation_names.items(): + relations.append( + self._register_relation_to_fetch( + relation_def, alias, include_in_return=False + ) + ) + + self.relations_to_fetch = relations + return self def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" @@ -1207,6 +1210,14 @@ async def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ + if not self.relations_to_fetch: + raise RuntimeError( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ) + if not self.relations_to_fetch[0]["include_in_return"]: + raise NotImplementedError( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." + ) results: list = [] qbuilder = self.query_cls(self, with_subgraph=True) await qbuilder.build_ast() diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 8221c979..0f8b044e 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1112,39 +1112,42 @@ def order_by(self, *props): return self - def _add_relations( - self, *relation_names, include_in_return=True, **aliased_relation_names + def _register_relation_to_fetch( + self, relation_def: Any, alias: str = None, include_in_return: bool = True ): - """Specify a set of relations to return.""" - relations = [] - - def register_relation_to_fetch(relation_def: Any, alias: str = None): - if isinstance(relation_def, Optional): - item = {"path": relation_def.relation, "optional": True} - else: - item = {"path": relation_def} - item["include_in_return"] = include_in_return - if alias: - item["alias"] = alias - relations.append(item) + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} + else: + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias + return item + def fetch_relations(self, *relation_names): + """Specify a set of relations to traverse and return.""" + relations = [] for relation_name in relation_names: - register_relation_to_fetch(relation_name) - for alias, relation_def in aliased_relation_names.items(): - register_relation_to_fetch(relation_def, alias) - + relations.append(self._register_relation_to_fetch(relation_name)) self.relations_to_fetch = relations return self - def fetch_relations(self, *relation_names, **aliased_relation_names): - """Specify a set of relations to traverse and return.""" - return self._add_relations(*relation_names, **aliased_relation_names) - def traverse_relations(self, *relation_names, **aliased_relation_names): """Specify a set of relations to traverse only.""" - return self._add_relations( - *relation_names, include_in_return=False, **aliased_relation_names - ) + relations = [] + for relation_name in relation_names: + relations.append( + self._register_relation_to_fetch(relation_name, include_in_return=False) + ) + for alias, relation_def in aliased_relation_names.items(): + relations.append( + self._register_relation_to_fetch( + relation_def, alias, include_in_return=False + ) + ) + + self.relations_to_fetch = relations + return self def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" @@ -1205,6 +1208,14 @@ def resolve_subgraph(self) -> list: we use a dedicated property to store node's relations. """ + if not self.relations_to_fetch: + raise RuntimeError( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ) + if not self.relations_to_fetch[0]["include_in_return"]: + raise NotImplementedError( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." + ) results: list = [] qbuilder = self.query_cls(self, with_subgraph=True) qbuilder.build_ast() diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index ff8261ef..216fbe7e 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from test._async_compat import mark_async_test @@ -654,6 +655,24 @@ async def test_resolve_subgraph(): await nescafe.species.connect(arabica) await nescafe_gold.species.connect(robusta) + with raises( + RuntimeError, + match=re.escape( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ), + ): + result = await Supplier.nodes.resolve_subgraph() + + with raises( + NotImplementedError, + match=re.escape( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." + ), + ): + result = await Supplier.nodes.traverse_relations( + "coffees__species" + ).resolve_subgraph() + result = await Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() assert len(result) == 2 diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 5aecf3cf..d015776a 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from test._async_compat import mark_sync_test @@ -642,6 +643,24 @@ def test_resolve_subgraph(): nescafe.species.connect(arabica) nescafe_gold.species.connect(robusta) + with raises( + RuntimeError, + match=re.escape( + "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()." + ), + ): + result = Supplier.nodes.resolve_subgraph() + + with raises( + NotImplementedError, + match=re.escape( + "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead." + ), + ): + result = Supplier.nodes.traverse_relations( + "coffees__species" + ).resolve_subgraph() + result = Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph() assert len(result) == 2 From 065f1bd230ff91e0408a03b222b3f86ab6f9102b Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 17:32:49 +0200 Subject: [PATCH 3/4] Cannot ensure output order so make sure tests are running. --- test/async_/test_match_api.py | 2 -- test/sync_/test_match_api.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 216fbe7e..1c61589d 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -681,14 +681,12 @@ async def test_resolve_subgraph(): coffees = result[0]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert robusta == coffees._relations["species"] assert hasattr(result[1], "_relations") assert "coffees" in result[1]._relations coffees = result[1]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert arabica == coffees._relations["species"] @mark_async_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index d015776a..9f24afbf 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -669,14 +669,12 @@ def test_resolve_subgraph(): coffees = result[0]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert robusta == coffees._relations["species"] assert hasattr(result[1], "_relations") assert "coffees" in result[1]._relations coffees = result[1]._relations["coffees"] assert hasattr(coffees, "_relations") assert "species" in coffees._relations - assert arabica == coffees._relations["species"] @mark_sync_test From 40a60fbe4f325fe3f8cb6f963897fba51eb8f821 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Mon, 23 Sep 2024 17:36:30 +0200 Subject: [PATCH 4/4] Added test to cover use case with Optional match. --- test/async_/test_match_api.py | 27 +++++++++++++++++++++++++++ test/sync_/test_match_api.py | 27 +++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 1c61589d..b3546ef9 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -689,6 +689,33 @@ async def test_resolve_subgraph(): assert "species" in coffees._relations +@mark_async_test +async def test_resolve_subgraph_optional(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + + result = await Supplier.nodes.fetch_relations( + Optional("coffees__species") + ).resolve_subgraph() + assert len(result) == 1 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert coffees._relations["species"] == arabica + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 9f24afbf..ab843639 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -677,6 +677,33 @@ def test_resolve_subgraph(): assert "species" in coffees._relations +@mark_sync_test +def test_resolve_subgraph_optional(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() + + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + + result = Supplier.nodes.fetch_relations( + Optional("coffees__species") + ).resolve_subgraph() + assert len(result) == 1 + + assert hasattr(result[0], "_relations") + assert "coffees" in result[0]._relations + coffees = result[0]._relations["coffees"] + assert hasattr(coffees, "_relations") + assert "species" in coffees._relations + assert coffees._relations["species"] == arabica + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create