From 9114c2662ffff0a5fafe53b12d05bcf3c8d4ac4c Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 14 Aug 2024 15:39:56 +0200 Subject: [PATCH 1/5] FIx tests and doc for vector index --- doc/source/schema_management.rst | 5 +++- test/async_/test_label_install.py | 36 +++++++++++++++------------- test/sync_/test_label_install.py | 39 +++++++++++++++++-------------- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/doc/source/schema_management.rst b/doc/source/schema_management.rst index 61ad94a3..9d0dce0a 100644 --- a/doc/source/schema_management.rst +++ b/doc/source/schema_management.rst @@ -72,6 +72,9 @@ Full example: :: name = StringProperty( index=True, fulltext_index=FulltextIndex(analyzer='english', eventually_consistent=True) + ) + name_embedding = ArrayProperty( + FloatProperty(), vector_index=VectorIndex(dimensions=512, similarity_function='euclidean') ) @@ -83,7 +86,7 @@ The following constraints are supported: - ``unique_index=True``: This will create a uniqueness constraint on the property. Available for both nodes and relationships (Neo4j version 5.7 or higher). .. note:: - The uniquess constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose. + The uniqueness constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose. Extracting the schema from a database diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py index 832578f5..03d6e87b 100644 --- a/test/async_/test_label_install.py +++ b/test/async_/test_label_install.py @@ -6,9 +6,11 @@ from neo4j.exceptions import ClientError from neomodel import ( + ArrayProperty, AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, + FloatProperty, FulltextIndex, StringProperty, UniqueIdProperty, @@ -317,16 +319,17 @@ async def test_vector_index(): pytest.skip("Not supported before 5.15") class VectorIndexNode(AsyncStructuredNode): - name = StringProperty( - vector_index=VectorIndex(dimensions=256, similarity_function="euclidean") + embedding = ArrayProperty( + FloatProperty(), + vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"), ) await adb.install_labels(VectorIndexNode) indexes = await adb.list_indexes() index_names = [index["name"] for index in indexes] - assert "vector_index_VectorIndexNode_name" in index_names + assert "vector_index_VectorIndexNode_embedding" in index_names - await adb.cypher_query("DROP INDEX vector_index_VectorIndexNode_name") + await adb.cypher_query("DROP INDEX vector_index_VectorIndexNode_embedding") @mark_async_test @@ -338,11 +341,11 @@ async def test_vector_index_conflict(): with patch("sys.stdout", new=stream): await adb.cypher_query( - "CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" + "CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" ) class VectorIndexNodeConflict(AsyncStructuredNode): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) await adb.install_labels(VectorIndexNodeConflict, quiet=False) @@ -361,7 +364,7 @@ async def test_vector_index_not_supported(): ): class VectorIndexNodeOld(AsyncStructuredNode): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) await adb.install_labels(VectorIndexNodeOld) @@ -372,8 +375,9 @@ async def test_rel_vector_index(): pytest.skip("Not supported before 5.18") class VectorIndexRel(AsyncStructuredRel): - name = StringProperty( - vector_index=VectorIndex(dimensions=256, similarity_function="euclidean") + embedding = ArrayProperty( + FloatProperty(), + vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"), ) class VectorIndexRelNode(AsyncStructuredNode): @@ -384,9 +388,9 @@ class VectorIndexRelNode(AsyncStructuredNode): await adb.install_labels(VectorIndexRelNode) indexes = await adb.list_indexes() index_names = [index["name"] for index in indexes] - assert "vector_index_VECTOR_INDEX_REL_name" in index_names + assert "vector_index_VECTOR_INDEX_REL_embedding" in index_names - await adb.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_name") + await adb.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_embedding") @mark_async_test @@ -398,11 +402,11 @@ async def test_rel_vector_index_conflict(): with patch("sys.stdout", new=stream): await adb.cypher_query( - "CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" + "CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" ) class VectorIndexRelConflict(AsyncStructuredRel): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) class VectorIndexRelConflictNode(AsyncStructuredNode): has_rel = AsyncRelationshipTo( @@ -428,7 +432,7 @@ async def test_rel_vector_index_not_supported(): ): class VectorIndexRelOld(AsyncStructuredRel): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) class VectorIndexRelOldNode(AsyncStructuredNode): has_rel = AsyncRelationshipTo( @@ -522,7 +526,7 @@ class UnauthorizedFulltextNode(AsyncStructuredNode): with await adb.impersonate(unauthorized_user): class UnauthorizedVectorNode(AsyncStructuredNode): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) await adb.install_labels(UnauthorizedVectorNode) @@ -572,7 +576,7 @@ class UnauthorizedFulltextRelNode(AsyncStructuredNode): with await adb.impersonate(unauthorized_user): class UnauthorizedVectorRel(AsyncStructuredRel): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) class UnauthorizedVectorRelNode(AsyncStructuredNode): has_rel = AsyncRelationshipTo( diff --git a/test/sync_/test_label_install.py b/test/sync_/test_label_install.py index 8b3c5b3a..ef18324b 100644 --- a/test/sync_/test_label_install.py +++ b/test/sync_/test_label_install.py @@ -6,6 +6,8 @@ from neo4j.exceptions import ClientError from neomodel import ( + ArrayProperty, + FloatProperty, FulltextIndex, RelationshipTo, StringProperty, @@ -26,8 +28,7 @@ class NodeWithConstraint(StructuredNode): name = StringProperty(unique_index=True) -class NodeWithRelationship(StructuredNode): - ... +class NodeWithRelationship(StructuredNode): ... class IndexedRelationship(StructuredRel): @@ -317,16 +318,17 @@ def test_vector_index(): pytest.skip("Not supported before 5.15") class VectorIndexNode(StructuredNode): - name = StringProperty( - vector_index=VectorIndex(dimensions=256, similarity_function="euclidean") + embedding = ArrayProperty( + FloatProperty(), + vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"), ) db.install_labels(VectorIndexNode) indexes = db.list_indexes() index_names = [index["name"] for index in indexes] - assert "vector_index_VectorIndexNode_name" in index_names + assert "vector_index_VectorIndexNode_embedding" in index_names - db.cypher_query("DROP INDEX vector_index_VectorIndexNode_name") + db.cypher_query("DROP INDEX vector_index_VectorIndexNode_embedding") @mark_sync_test @@ -338,11 +340,11 @@ def test_vector_index_conflict(): with patch("sys.stdout", new=stream): db.cypher_query( - "CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" + "CREATE VECTOR INDEX FOR (n:VectorIndexNodeConflict) ON n.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" ) class VectorIndexNodeConflict(StructuredNode): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) db.install_labels(VectorIndexNodeConflict, quiet=False) @@ -361,7 +363,7 @@ def test_vector_index_not_supported(): ): class VectorIndexNodeOld(StructuredNode): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) db.install_labels(VectorIndexNodeOld) @@ -372,8 +374,9 @@ def test_rel_vector_index(): pytest.skip("Not supported before 5.18") class VectorIndexRel(StructuredRel): - name = StringProperty( - vector_index=VectorIndex(dimensions=256, similarity_function="euclidean") + embedding = ArrayProperty( + FloatProperty(), + vector_index=VectorIndex(dimensions=256, similarity_function="euclidean"), ) class VectorIndexRelNode(StructuredNode): @@ -384,9 +387,9 @@ class VectorIndexRelNode(StructuredNode): db.install_labels(VectorIndexRelNode) indexes = db.list_indexes() index_names = [index["name"] for index in indexes] - assert "vector_index_VECTOR_INDEX_REL_name" in index_names + assert "vector_index_VECTOR_INDEX_REL_embedding" in index_names - db.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_name") + db.cypher_query("DROP INDEX vector_index_VECTOR_INDEX_REL_embedding") @mark_sync_test @@ -398,11 +401,11 @@ def test_rel_vector_index_conflict(): with patch("sys.stdout", new=stream): db.cypher_query( - "CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.name OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" + "CREATE VECTOR INDEX FOR ()-[r:VECTOR_INDEX_REL_CONFLICT]-() ON r.embedding OPTIONS{indexConfig:{`vector.similarity_function`:'cosine', `vector.dimensions`:1536}}" ) class VectorIndexRelConflict(StructuredRel): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) class VectorIndexRelConflictNode(StructuredNode): has_rel = RelationshipTo( @@ -428,7 +431,7 @@ def test_rel_vector_index_not_supported(): ): class VectorIndexRelOld(StructuredRel): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) class VectorIndexRelOldNode(StructuredNode): has_rel = RelationshipTo( @@ -520,7 +523,7 @@ class UnauthorizedFulltextNode(StructuredNode): with db.impersonate(unauthorized_user): class UnauthorizedVectorNode(StructuredNode): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) db.install_labels(UnauthorizedVectorNode) @@ -570,7 +573,7 @@ class UnauthorizedFulltextRelNode(StructuredNode): with db.impersonate(unauthorized_user): class UnauthorizedVectorRel(StructuredRel): - name = StringProperty(vector_index=VectorIndex()) + embedding = ArrayProperty(FloatProperty(), vector_index=VectorIndex()) class UnauthorizedVectorRelNode(StructuredNode): has_rel = RelationshipTo( From d78da5e79044438d231f3e69b626019a5eba42d7 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 14 Aug 2024 16:24:30 +0200 Subject: [PATCH 2/5] Fix test --- test/async_/test_match_api.py | 5 +++-- test/sync_/test_match_api.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index e3195448..4f549dd5 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -173,8 +173,9 @@ async def test_double_traverse(): results = [node async for node in qb._execute()] assert len(results) == 2 - assert results[0].name == "Decafe" - assert results[1].name == "Nescafe plus" + names = [n.name for n in results] + assert "Decafe" in names + assert "Nescafe plus" in names @mark_async_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 170a7363..d9c90bb9 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -164,8 +164,9 @@ def test_double_traverse(): results = [node for node in qb._execute()] assert len(results) == 2 - assert results[0].name == "Decafe" - assert results[1].name == "Nescafe plus" + names = [n.name for n in results] + assert "Decafe" in names + assert "Nescafe plus" in names @mark_sync_test From 02e96190568aa0207534485d8cd7d9bfa0fc5ebe Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 14 Aug 2024 16:38:58 +0200 Subject: [PATCH 3/5] Prepare rc branch --- Changelog | 3 +++ doc/source/configuration.rst | 2 +- neomodel/_version.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Changelog b/Changelog index fe655005..b65b90a7 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,6 @@ +Version 5.3.3 2024-09 +* Fixes vector index doc and test + Version 5.3.2 2024-06 * Add support for Vector and Fulltext indexes creation * Add DateTimeNeo4jFormatProperty for Neo4j native datetime format diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index e8c0a38b..e5d5d5d8 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default config.RESOLVER = None # default config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = neomodel/v5.3.2 # default + config.USER_AGENT = neomodel/v5.3.3 # default Setting the database name, if different from the default one:: diff --git a/neomodel/_version.py b/neomodel/_version.py index 07f0e9e2..d2f4a6f4 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.3.2" +__version__ = "5.3.3" From fcfd45db480f4f2a1204a5df805dd3e9f906b468 Mon Sep 17 00:00:00 2001 From: Christoph Brosch Date: Tue, 3 Sep 2024 09:24:19 +0200 Subject: [PATCH 4/5] Update getting_started.rst Fixed wrongly typed out argument in example command --- doc/source/getting_started.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 2203b8a3..6e8a5aa0 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -79,7 +79,7 @@ Database Inspection - Requires APOC =================================== You can inspect an existing Neo4j database to generate a neomodel definition file using the ``inspect`` command:: - $ neomodel_inspect_database -db bolt://neo4j_username:neo4j_password@localhost:7687 --write-to yourapp/models.py + $ neomodel_inspect_database --db bolt://neo4j_username:neo4j_password@localhost:7687 --write-to yourapp/models.py This will generate a file called ``models.py`` in the ``yourapp`` directory. This file can be used as a starting point, and will contain the necessary module imports, as well as class definition for nodes and, if relevant, relationships. From 1dc380b8f1ca4eea825bf27a65da7fb3f298918e Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Fri, 20 Sep 2024 16:44:13 +0200 Subject: [PATCH 5/5] Added support for annotations and calling aggregating functions. Also introduced a new traverse_relations() method which only return the first hop of the path in the return clause. --- neomodel/async_/match.py | 112 ++++++++++++++++++++++++++++------ neomodel/sync_/match.py | 112 ++++++++++++++++++++++++++++------ test/async_/test_match_api.py | 42 +++++++++++++ test/sync_/test_match_api.py | 43 ++++++++++++- 4 files changed, 268 insertions(+), 41 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 7f0435fe..c203dcc2 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.exceptions import MultipleNodesReturned @@ -507,7 +507,7 @@ async def build_traversal(self, traversal): return traversal_ident - def _additional_return(self, name): + def _additional_return(self, name: str): if name not in self._ast.additional_return and name != self._ast.return_clause: self._ast.additional_return.append(name) @@ -515,7 +515,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - for index, part in enumerate(path.split("__")): + parts = path.split("__") + for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source if "node_class" not in relationship.definition: @@ -523,11 +524,16 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: rhs_label = relationship.definition["node_class"].__label__ rel_reference = f'{relationship.definition["node_class"]}_{part}' self._node_counters[rel_reference] += 1 - rhs_name = ( - f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" - ) + if index + 1 == len(parts) and "alias" in relation: + # If an alias is defined, use it to store the last hop in the path + rhs_name = relation["alias"] + else: + rhs_name = ( + f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + ) rhs_ident = f"{rhs_name}:{rhs_label}" - self._additional_return(rhs_name) + if relation["include_in_return"]: + self._additional_return(rhs_name) if not stmt: lhs_label = source_class_iterator.__label__ lhs_name = lhs_label.lower() @@ -537,13 +543,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - else: + elif relation["include_in_return"]: self._additional_return(lhs_name) else: lhs_ident = stmt rel_ident = self.create_ident() - self._additional_return(rel_ident) + if relation["include_in_return"]: + self._additional_return(rel_ident) stmt = _rel_helper( lhs=lhs_ident, rhs=rhs_ident, @@ -683,7 +690,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) def build_query(self): - query = "" + query: str = "" if self._ast.lookup: query += self._ast.lookup @@ -710,12 +717,20 @@ def build_query(self): query += self._ast.with_clause query += " RETURN " + returned_items: list[str] = [] if self._ast.return_clause: - query += self._ast.return_clause + returned_items.append(self._ast.return_clause) if self._ast.additional_return: - if self._ast.return_clause: - query += ", " - query += ", ".join(self._ast.additional_return) + returned_items += self._ast.additional_return + if hasattr(self.node_set, "_extra_results"): + for varname, vardef in self.node_set._extra_results.items(): + if varname in returned_items: + # We're about to override an existing variable, delete it first to + # avoid duplicate error + returned_items.remove(varname) + returned_items.append(f"{str(vardef)} AS {varname}") + + query += ", ".join(returned_items) if self._ast.order_by: query += " ORDER BY " @@ -754,7 +769,6 @@ async def _count(self): async def _contains(self, node_element_id): # inject id = into ast if not self._ast.return_clause: - print(self._ast.additional_return) self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") @@ -881,6 +895,25 @@ class Optional: relation: str +@dataclass +class AggregatingFunction: + """Base aggregating function class.""" + + input_name: str + + +@dataclass +class Collect(AggregatingFunction): + """collect() function.""" + + distinct: bool = False + + def __str__(self): + if self.distinct: + return f"collect(DISTINCT {self.input_name})" + return f"collect({self.input_name})" + + class AsyncNodeSet(AsyncBaseSet): """ A class representing as set of nodes matching common query parameters @@ -908,6 +941,7 @@ def __init__(self, source): self.dont_match = {} self.relations_to_fetch: list = [] + self._extra_results: dict[str] = {} def __await__(self): return self.all().__await__() @@ -1062,18 +1096,56 @@ def order_by(self, *props): return self - def fetch_relations(self, *relation_names): + def _add_relations( + self, *relation_names, include_in_return=True, **aliased_relation_names + ): """Specify a set of relations to return.""" relations = [] - for relation_name in relation_names: - if isinstance(relation_name, Optional): - item = {"path": relation_name.relation, "optional": True} + + def register_relation_to_fetch(relation_def: Any, alias: str = None): + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} else: - item = {"path": relation_name} + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias relations.append(item) + + for relation_name in relation_names: + register_relation_to_fetch(relation_name) + for alias, relation_def in aliased_relation_names.items(): + register_relation_to_fetch(relation_def, alias) + self.relations_to_fetch = relations return self + def fetch_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse and return.""" + return self._add_relations(*relation_names, **aliased_relation_names) + + def traverse_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse only.""" + return self._add_relations( + *relation_names, include_in_return=False, **aliased_relation_names + ) + + def annotate(self, *vars, **aliased_vars): + """Annotate node set results with extra variables.""" + + def register_extra_var(vardef, varname: str = None): + if isinstance(vardef, AggregatingFunction): + self._extra_results[varname if varname else vardef.input_name] = vardef + else: + raise NotImplementedError + + for vardef in vars: + register_extra_var(vardef) + for varname, vardef in aliased_vars.items(): + register_extra_var(vardef, varname) + + return self + class AsyncTraversal(AsyncBaseSet): """ diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 928842c2..8f163c53 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -507,7 +507,7 @@ def build_traversal(self, traversal): return traversal_ident - def _additional_return(self, name): + def _additional_return(self, name: str): if name not in self._ast.additional_return and name != self._ast.return_clause: self._ast.additional_return.append(name) @@ -515,7 +515,8 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: path: str = relation["path"] stmt: str = "" source_class_iterator = source_class - for index, part in enumerate(path.split("__")): + parts = path.split("__") + for index, part in enumerate(parts): relationship = getattr(source_class_iterator, part) # build source if "node_class" not in relationship.definition: @@ -523,11 +524,16 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: rhs_label = relationship.definition["node_class"].__label__ rel_reference = f'{relationship.definition["node_class"]}_{part}' self._node_counters[rel_reference] += 1 - rhs_name = ( - f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" - ) + if index + 1 == len(parts) and "alias" in relation: + # If an alias is defined, use it to store the last hop in the path + rhs_name = relation["alias"] + else: + rhs_name = ( + f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + ) rhs_ident = f"{rhs_name}:{rhs_label}" - self._additional_return(rhs_name) + if relation["include_in_return"]: + self._additional_return(rhs_name) if not stmt: lhs_label = source_class_iterator.__label__ lhs_name = lhs_label.lower() @@ -537,13 +543,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - else: + elif relation["include_in_return"]: self._additional_return(lhs_name) else: lhs_ident = stmt rel_ident = self.create_ident() - self._additional_return(rel_ident) + if relation["include_in_return"]: + self._additional_return(rel_ident) stmt = _rel_helper( lhs=lhs_ident, rhs=rhs_ident, @@ -683,7 +690,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): self._ast.where.append(" AND ".join(stmts)) def build_query(self): - query = "" + query: str = "" if self._ast.lookup: query += self._ast.lookup @@ -710,12 +717,20 @@ def build_query(self): query += self._ast.with_clause query += " RETURN " + returned_items: list[str] = [] if self._ast.return_clause: - query += self._ast.return_clause + returned_items.append(self._ast.return_clause) if self._ast.additional_return: - if self._ast.return_clause: - query += ", " - query += ", ".join(self._ast.additional_return) + returned_items += self._ast.additional_return + if hasattr(self.node_set, "_extra_results"): + for varname, vardef in self.node_set._extra_results.items(): + if varname in returned_items: + # We're about to override an existing variable, delete it first to + # avoid duplicate error + returned_items.remove(varname) + returned_items.append(f"{str(vardef)} AS {varname}") + + query += ", ".join(returned_items) if self._ast.order_by: query += " ORDER BY " @@ -754,7 +769,6 @@ def _count(self): def _contains(self, node_element_id): # inject id = into ast if not self._ast.return_clause: - print(self._ast.additional_return) self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") @@ -877,6 +891,25 @@ class Optional: relation: str +@dataclass +class AggregatingFunction: + """Base aggregating function class.""" + + input_name: str + + +@dataclass +class Collect(AggregatingFunction): + """collect() function.""" + + distinct: bool = False + + def __str__(self): + if self.distinct: + return f"collect(DISTINCT {self.input_name})" + return f"collect({self.input_name})" + + class NodeSet(BaseSet): """ A class representing as set of nodes matching common query parameters @@ -904,6 +937,7 @@ def __init__(self, source): self.dont_match = {} self.relations_to_fetch: list = [] + self._extra_results: dict[str] = {} def __await__(self): return self.all().__await__() @@ -1058,18 +1092,56 @@ def order_by(self, *props): return self - def fetch_relations(self, *relation_names): + def _add_relations( + self, *relation_names, include_in_return=True, **aliased_relation_names + ): """Specify a set of relations to return.""" relations = [] - for relation_name in relation_names: - if isinstance(relation_name, Optional): - item = {"path": relation_name.relation, "optional": True} + + def register_relation_to_fetch(relation_def: Any, alias: str = None): + if isinstance(relation_def, Optional): + item = {"path": relation_def.relation, "optional": True} else: - item = {"path": relation_name} + item = {"path": relation_def} + item["include_in_return"] = include_in_return + if alias: + item["alias"] = alias relations.append(item) + + for relation_name in relation_names: + register_relation_to_fetch(relation_name) + for alias, relation_def in aliased_relation_names.items(): + register_relation_to_fetch(relation_def, alias) + self.relations_to_fetch = relations return self + def fetch_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse and return.""" + return self._add_relations(*relation_names, **aliased_relation_names) + + def traverse_relations(self, *relation_names, **aliased_relation_names): + """Specify a set of relations to traverse only.""" + return self._add_relations( + *relation_names, include_in_return=False, **aliased_relation_names + ) + + def annotate(self, *vars, **aliased_vars): + """Annotate node set results with extra variables.""" + + def register_extra_var(vardef, varname: str = None): + if isinstance(vardef, AggregatingFunction): + self._extra_results[varname if varname else vardef.input_name] = vardef + else: + raise NotImplementedError + + for vardef in vars: + register_extra_var(vardef) + for varname, vardef in aliased_vars.items(): + register_extra_var(vardef, varname) + + return self + class Traversal(BaseSet): """ diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index e3195448..27406f2d 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -15,12 +15,14 @@ Q, StringProperty, UniqueIdProperty, + adb, ) from neomodel._async_compat.util import AsyncUtil from neomodel.async_.match import ( AsyncNodeSet, AsyncQueryBuilder, AsyncTraversal, + Collect, Optional, ) from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined @@ -595,6 +597,46 @@ async def test_fetch_relations(): ) +@mark_async_test +async def test_annotate_and_collect(): + # Clean DB before we start anything... + await adb.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe 1002", price=99).save() + nescafe_gold = await Coffee(name="Nescafe 1003", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + await nescafe_gold.species.connect(robusta) + await nescafe_gold.species.connect(arabica) + + result = ( + await Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species")) + .all() + ) + assert len(result) == 1 + assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + + result = ( + await Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + result = ( + await Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(all_species=Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 170a7363..f872dad4 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -15,10 +15,11 @@ StructuredNode, StructuredRel, UniqueIdProperty, + db, ) from neomodel._async_compat.util import Util from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined -from neomodel.sync_.match import NodeSet, Optional, QueryBuilder, Traversal +from neomodel.sync_.match import Collect, NodeSet, Optional, QueryBuilder, Traversal class SupplierRel(StructuredRel): @@ -584,6 +585,46 @@ def test_fetch_relations(): ) +@mark_sync_test +def test_annotate_and_collect(): + # Clean DB before we start anything... + db.cypher_query("MATCH (n) DETACH DELETE n") + + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe 1002", price=99).save() + nescafe_gold = Coffee(name="Nescafe 1003", price=11).save() + + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + nescafe.suppliers.connect(tesco) + nescafe_gold.suppliers.connect(tesco) + nescafe.species.connect(arabica) + nescafe_gold.species.connect(robusta) + nescafe_gold.species.connect(arabica) + + result = ( + Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species")) + .all() + ) + assert len(result) == 1 + assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates) + + result = ( + Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + result = ( + Supplier.nodes.traverse_relations(species="coffees__species") + .annotate(all_species=Collect("species", distinct=True)) + .all() + ) + assert len(result[0][1][0]) == 2 # 2 species must be there + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create