Skip to content

Commit

Permalink
Merge pull request #829 from neo4j-contrib/feature/annotate
Browse files Browse the repository at this point in the history
Added support for annotations and calling aggregating functions.
  • Loading branch information
mariusconjeaud authored Sep 23, 2024
2 parents 95566a1 + 572c91c commit 2060309
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 81 deletions.
3 changes: 3 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Version 5.3.3 2024-09
* Fixes vector index doc and test

Version 5.3.2 2024-06
* Add support for Vector and Fulltext indexes creation
* Add DateTimeNeo4jFormatProperty for Neo4j native datetime format
Expand Down
2 changes: 1 addition & 1 deletion doc/source/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Database Inspection - Requires APOC
===================================
You can inspect an existing Neo4j database to generate a neomodel definition file using the ``inspect`` command::

$ neomodel_inspect_database -db bolt://neo4j_username:neo4j_password@localhost:7687 --write-to yourapp/models.py
$ neomodel_inspect_database --db bolt://neo4j_username:neo4j_password@localhost:7687 --write-to yourapp/models.py

This will generate a file called ``models.py`` in the ``yourapp`` directory. This file can be used as a starting point,
and will contain the necessary module imports, as well as class definitions for nodes and, if relevant, relationships.
Expand Down
5 changes: 4 additions & 1 deletion doc/source/schema_management.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ Full example: ::
name = StringProperty(
index=True,
fulltext_index=FulltextIndex(analyzer='english', eventually_consistent=True)
)
name_embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=512, similarity_function='euclidean')
)

Expand All @@ -83,7 +86,7 @@ The following constraints are supported:
- ``unique_index=True``: This will create a uniqueness constraint on the property. Available for both nodes and relationships (Neo4j version 5.7 or higher).

.. note::
The uniquess constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose.
The property existence constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose.


Extracting the schema from a database
Expand Down
112 changes: 92 additions & 20 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

from neomodel.async_.core import AsyncStructuredNode, adb
from neomodel.exceptions import MultipleNodesReturned
Expand Down Expand Up @@ -507,27 +507,33 @@ async def build_traversal(self, traversal):

return traversal_ident

def _additional_return(self, name):
def _additional_return(self, name: str):
if name not in self._ast.additional_return and name != self._ast.return_clause:
self._ast.additional_return.append(name)

def build_traversal_from_path(self, relation: dict, source_class) -> str:
path: str = relation["path"]
stmt: str = ""
source_class_iterator = source_class
for index, part in enumerate(path.split("__")):
parts = path.split("__")
for index, part in enumerate(parts):
relationship = getattr(source_class_iterator, part)
# build source
if "node_class" not in relationship.definition:
relationship.lookup_node_class()
rhs_label = relationship.definition["node_class"].__label__
rel_reference = f'{relationship.definition["node_class"]}_{part}'
self._node_counters[rel_reference] += 1
rhs_name = (
f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
)
if index + 1 == len(parts) and "alias" in relation:
# If an alias is defined, use it to store the last hop in the path
rhs_name = relation["alias"]
else:
rhs_name = (
f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
)
rhs_ident = f"{rhs_name}:{rhs_label}"
self._additional_return(rhs_name)
if relation["include_in_return"]:
self._additional_return(rhs_name)
if not stmt:
lhs_label = source_class_iterator.__label__
lhs_name = lhs_label.lower()
Expand All @@ -537,13 +543,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
else:
elif relation["include_in_return"]:
self._additional_return(lhs_name)
else:
lhs_ident = stmt

rel_ident = self.create_ident()
self._additional_return(rel_ident)
if relation["include_in_return"]:
self._additional_return(rel_ident)
stmt = _rel_helper(
lhs=lhs_ident,
rhs=rhs_ident,
Expand Down Expand Up @@ -683,7 +690,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
self._ast.where.append(" AND ".join(stmts))

def build_query(self):
query = ""
query: str = ""

if self._ast.lookup:
query += self._ast.lookup
Expand All @@ -710,12 +717,20 @@ def build_query(self):
query += self._ast.with_clause

query += " RETURN "
returned_items: list[str] = []
if self._ast.return_clause:
query += self._ast.return_clause
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
if self._ast.return_clause:
query += ", "
query += ", ".join(self._ast.additional_return)
returned_items += self._ast.additional_return
if hasattr(self.node_set, "_extra_results"):
for varname, vardef in self.node_set._extra_results.items():
if varname in returned_items:
# We're about to override an existing variable, delete it first to
# avoid duplicate error
returned_items.remove(varname)
returned_items.append(f"{str(vardef)} AS {varname}")

query += ", ".join(returned_items)

if self._ast.order_by:
query += " ORDER BY "
Expand Down Expand Up @@ -754,7 +769,6 @@ async def _count(self):
async def _contains(self, node_element_id):
# inject id = into ast
if not self._ast.return_clause:
print(self._ast.additional_return)
self._ast.return_clause = self._ast.additional_return[0]
ident = self._ast.return_clause
place_holder = self._register_place_holder(ident + "_contains")
Expand Down Expand Up @@ -881,6 +895,25 @@ class Optional:
relation: str


@dataclass
class AggregatingFunction:
    """Common parent of the aggregating functions accepted by ``annotate()``.

    Concrete subclasses supply ``__str__``; the query builder calls ``str()``
    on the instance to obtain the expression placed in the RETURN clause.
    """

    # Name of the query variable the function is applied to. It is also the
    # key used by ``annotate()`` when no explicit alias is given.
    input_name: str


@dataclass
class Collect(AggregatingFunction):
    """Aggregating function rendered as a ``collect(...)`` expression.

    The inherited ``input_name`` is interpolated verbatim as the argument.
    When ``distinct`` is true the argument is prefixed with ``DISTINCT``.
    """

    distinct: bool = False

    def __str__(self):
        # Both literals are kept whole so the emitted text is easy to grep.
        return (
            f"collect(DISTINCT {self.input_name})"
            if self.distinct
            else f"collect({self.input_name})"
        )


class AsyncNodeSet(AsyncBaseSet):
"""
A class representing a set of nodes matching common query parameters
Expand Down Expand Up @@ -908,6 +941,7 @@ def __init__(self, source):
self.dont_match = {}

self.relations_to_fetch: list = []
self._extra_results: dict[str] = {}

def __await__(self):
return self.all().__await__()
Expand Down Expand Up @@ -1062,18 +1096,56 @@ def order_by(self, *props):

return self

def _add_relations(
    self, *relation_names, include_in_return=True, **aliased_relation_names
):
    """Record the relationship paths to traverse for this node set.

    :param relation_names: relationship paths; each is either the path
        itself or an ``Optional`` wrapper whose ``relation`` attribute
        holds the path.
    :param include_in_return: stored on every registered entry under the
        ``include_in_return`` key.
    :param aliased_relation_names: ``alias=path`` pairs; a non-empty alias
        is stored on the entry under the ``alias`` key.
    :return: ``self``, so calls can be chained.

    Note that the list built here *replaces* ``self.relations_to_fetch``;
    relations registered by an earlier call are not kept.
    """
    relations = []

    def register_relation_to_fetch(relation_def: Any, alias: str = None):
        # ``Optional`` is the wrapper class declared in this module (it
        # carries a ``relation`` field), not ``typing.Optional``.
        if isinstance(relation_def, Optional):
            item = {"path": relation_def.relation, "optional": True}
        else:
            item = {"path": relation_def}
        item["include_in_return"] = include_in_return
        if alias:
            item["alias"] = alias
        relations.append(item)

    # Positional paths first, then aliased ones, preserving call order.
    for relation_name in relation_names:
        register_relation_to_fetch(relation_name)
    for alias, relation_def in aliased_relation_names.items():
        register_relation_to_fetch(relation_def, alias)

    self.relations_to_fetch = relations
    return self

def fetch_relations(self, *relation_names, **aliased_relation_names):
    """Register relations that are traversed and also returned.

    Positional paths and ``alias=path`` keyword arguments are forwarded
    unchanged to ``_add_relations``, leaving ``include_in_return`` at its
    default. The value returned by ``_add_relations`` is passed through.
    """
    node_set = self._add_relations(*relation_names, **aliased_relation_names)
    return node_set

def traverse_relations(self, *relation_names, **aliased_relation_names):
    """Register relations that are traversed but not returned.

    Same forwarding as ``fetch_relations``, except that every entry is
    registered with ``include_in_return=False``. The value returned by
    ``_add_relations`` is passed through.
    """
    node_set = self._add_relations(
        *relation_names, include_in_return=False, **aliased_relation_names
    )
    return node_set

def annotate(self, *vars, **aliased_vars):
    """Annotate node set results with extra variables.

    :param vars: ``AggregatingFunction`` instances; each is registered in
        ``self._extra_results`` under its own ``input_name``.
    :param aliased_vars: ``name=AggregatingFunction`` pairs; each is
        registered under ``name`` (or under ``input_name`` if the name is
        empty).
    :raises NotImplementedError: if any supplied value is not an
        ``AggregatingFunction``. Nothing is registered in that case.
    :return: ``self``, so calls can be chained.
    """
    # Positional definitions carry no explicit name; pair them with None so
    # both kinds of argument go through the same code path.
    candidates = [(None, vardef) for vardef in vars]
    candidates.extend(aliased_vars.items())

    # Validate everything up front so that a bad argument cannot leave the
    # node set half-annotated.
    for _, vardef in candidates:
        if not isinstance(vardef, AggregatingFunction):
            raise NotImplementedError(
                "annotate() only supports AggregatingFunction instances, "
                f"got {type(vardef).__name__}"
            )

    for varname, vardef in candidates:
        self._extra_results[varname if varname else vardef.input_name] = vardef

    return self


class AsyncTraversal(AsyncBaseSet):
"""
Expand Down
Loading

0 comments on commit 2060309

Please sign in to comment.