Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for annotations and calling aggregating functions. #829

Merged
merged 9 commits into from
Sep 23, 2024
3 changes: 3 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Version 5.3.3 2024-09
* Fixes vector index doc and test

Version 5.3.2 2024-06
* Add support for Vector and Fulltext indexes creation
* Add DateTimeNeo4jFormatProperty for Neo4j native datetime format
Expand Down
2 changes: 1 addition & 1 deletion doc/source/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Database Inspection - Requires APOC
===================================
You can inspect an existing Neo4j database to generate a neomodel definition file using the ``inspect`` command::

$ neomodel_inspect_database -db bolt://neo4j_username:neo4j_password@localhost:7687 --write-to yourapp/models.py
$ neomodel_inspect_database --db bolt://neo4j_username:neo4j_password@localhost:7687 --write-to yourapp/models.py

This will generate a file called ``models.py`` in the ``yourapp`` directory. This file can be used as a starting point,
and will contain the necessary module imports, as well as class definition for nodes and, if relevant, relationships.
Expand Down
5 changes: 4 additions & 1 deletion doc/source/schema_management.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ Full example: ::
name = StringProperty(
index=True,
fulltext_index=FulltextIndex(analyzer='english', eventually_consistent=True)
)
name_embedding = ArrayProperty(
FloatProperty(),
vector_index=VectorIndex(dimensions=512, similarity_function='euclidean')
)

Expand All @@ -83,7 +86,7 @@ The following constraints are supported:
- ``unique_index=True``: This will create a uniqueness constraint on the property. Available for both nodes and relationships (Neo4j version 5.7 or higher).

.. note::
The uniquess constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose.
The uniqueness constraint of Neo4j is not supported as such, but using ``required=True`` on a property serves the same purpose.


Extracting the schema from a database
Expand Down
112 changes: 92 additions & 20 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

from neomodel.async_.core import AsyncStructuredNode, adb
from neomodel.exceptions import MultipleNodesReturned
Expand Down Expand Up @@ -507,27 +507,33 @@ async def build_traversal(self, traversal):

return traversal_ident

def _additional_return(self, name):
def _additional_return(self, name: str):
if name not in self._ast.additional_return and name != self._ast.return_clause:
self._ast.additional_return.append(name)

def build_traversal_from_path(self, relation: dict, source_class) -> str:
path: str = relation["path"]
stmt: str = ""
source_class_iterator = source_class
for index, part in enumerate(path.split("__")):
parts = path.split("__")
for index, part in enumerate(parts):
relationship = getattr(source_class_iterator, part)
# build source
if "node_class" not in relationship.definition:
relationship.lookup_node_class()
rhs_label = relationship.definition["node_class"].__label__
rel_reference = f'{relationship.definition["node_class"]}_{part}'
self._node_counters[rel_reference] += 1
rhs_name = (
f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
)
if index + 1 == len(parts) and "alias" in relation:
# If an alias is defined, use it to store the last hop in the path
rhs_name = relation["alias"]
else:
rhs_name = (
f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
)
rhs_ident = f"{rhs_name}:{rhs_label}"
self._additional_return(rhs_name)
if relation["include_in_return"]:
self._additional_return(rhs_name)
if not stmt:
lhs_label = source_class_iterator.__label__
lhs_name = lhs_label.lower()
Expand All @@ -537,13 +543,14 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
else:
elif relation["include_in_return"]:
self._additional_return(lhs_name)
else:
lhs_ident = stmt

rel_ident = self.create_ident()
self._additional_return(rel_ident)
if relation["include_in_return"]:
self._additional_return(rel_ident)
stmt = _rel_helper(
lhs=lhs_ident,
rhs=rhs_ident,
Expand Down Expand Up @@ -683,7 +690,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
self._ast.where.append(" AND ".join(stmts))

def build_query(self):
query = ""
query: str = ""

if self._ast.lookup:
query += self._ast.lookup
Expand All @@ -710,12 +717,20 @@ def build_query(self):
query += self._ast.with_clause

query += " RETURN "
returned_items: list[str] = []
if self._ast.return_clause:
query += self._ast.return_clause
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
if self._ast.return_clause:
query += ", "
query += ", ".join(self._ast.additional_return)
returned_items += self._ast.additional_return
if hasattr(self.node_set, "_extra_results"):
for varname, vardef in self.node_set._extra_results.items():
if varname in returned_items:
# We're about to override an existing variable, delete it first to
# avoid duplicate error
returned_items.remove(varname)
returned_items.append(f"{str(vardef)} AS {varname}")

query += ", ".join(returned_items)

if self._ast.order_by:
query += " ORDER BY "
Expand Down Expand Up @@ -754,7 +769,6 @@ async def _count(self):
async def _contains(self, node_element_id):
# inject id = into ast
if not self._ast.return_clause:
print(self._ast.additional_return)
self._ast.return_clause = self._ast.additional_return[0]
ident = self._ast.return_clause
place_holder = self._register_place_holder(ident + "_contains")
Expand Down Expand Up @@ -881,6 +895,25 @@ class Optional:
relation: str


@dataclass
class AggregatingFunction:
    """Base aggregating function class.

    Subclasses render themselves as a query expression through ``__str__``.
    The base class has no rendering of its own, so converting it to a string
    fails fast instead of silently producing the dataclass repr.
    """

    # Name of the query variable the function is applied to.
    input_name: str

    def __str__(self) -> str:
        raise NotImplementedError(
            f"{type(self).__name__} does not define a query representation"
        )


@dataclass
class Collect(AggregatingFunction):
    """collect() function."""

    # When True, render ``collect(DISTINCT ...)`` instead of ``collect(...)``.
    distinct: bool = False

    def __str__(self) -> str:
        if self.distinct:
            return f"collect(DISTINCT {self.input_name})"
        return f"collect({self.input_name})"


class AsyncNodeSet(AsyncBaseSet):
"""
A class representing as set of nodes matching common query parameters
Expand Down Expand Up @@ -908,6 +941,7 @@ def __init__(self, source):
self.dont_match = {}

self.relations_to_fetch: list = []
self._extra_results: dict[str] = {}

def __await__(self):
return self.all().__await__()
Expand Down Expand Up @@ -1062,18 +1096,56 @@ def order_by(self, *props):

return self

def _add_relations(
    self, *relation_names, include_in_return=True, **aliased_relation_names
):
    """Specify a set of relations to traverse, replacing any previous set.

    Each positional entry is either a relation path or an ``Optional``
    wrapper around one. Keyword entries work the same way, with the keyword
    stored as the entry's ``alias``.

    :param relation_names: relation paths (or ``Optional`` wrappers).
    :param include_in_return: flag stored on every entry under the
        ``"include_in_return"`` key.
    :param aliased_relation_names: ``alias=relation`` pairs.
    :return: ``self``, so calls can be chained.
    """
    relations = []

    def register_relation_to_fetch(relation_def: Any, alias: str = None):
        # ``Optional`` is this module's wrapper class (it exposes
        # ``.relation``), not ``typing.Optional``.
        if isinstance(relation_def, Optional):
            item = {"path": relation_def.relation, "optional": True}
        else:
            item = {"path": relation_def}
        item["include_in_return"] = include_in_return
        if alias:
            item["alias"] = alias
        relations.append(item)

    for relation_name in relation_names:
        register_relation_to_fetch(relation_name)
    for alias, relation_def in aliased_relation_names.items():
        register_relation_to_fetch(relation_def, alias)

    # Overwrites (does not extend) whatever was registered before.
    self.relations_to_fetch = relations
    return self

def fetch_relations(self, *relation_names, **aliased_relation_names):
    """Specify a set of relations to traverse and return.

    All arguments are forwarded unchanged to ``_add_relations``, whose
    result is returned.
    """
    node_set = self._add_relations(*relation_names, **aliased_relation_names)
    return node_set

def traverse_relations(self, *relation_names, **aliased_relation_names):
    """Specify a set of relations to traverse only.

    Forwards to ``_add_relations`` with ``include_in_return=False`` and
    returns its result.
    """
    node_set = self._add_relations(
        *relation_names, include_in_return=False, **aliased_relation_names
    )
    return node_set

def annotate(self, *variables, **aliased_vars):
    """Annotate node set results with extra variables.

    Every definition must be an ``AggregatingFunction``. Positional
    definitions are stored in ``self._extra_results`` under their own
    ``input_name``; keyword definitions are stored under the keyword used.
    A later definition with the same name replaces an earlier one.

    :raises NotImplementedError: if a definition is not an
        ``AggregatingFunction``.
    :return: ``self``, so calls can be chained.
    """

    def register_extra_var(vardef, varname=None):
        if not isinstance(vardef, AggregatingFunction):
            raise NotImplementedError(
                "annotate() only supports AggregatingFunction instances, "
                f"got {type(vardef).__name__}"
            )
        # Fall back to the function's input name when no alias was given.
        self._extra_results[varname if varname else vardef.input_name] = vardef

    for vardef in variables:
        register_extra_var(vardef)
    for varname, vardef in aliased_vars.items():
        register_extra_var(vardef, varname)

    return self


class AsyncTraversal(AsyncBaseSet):
"""
Expand Down
Loading
Loading