diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 032a5868..1a0450f7 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.12", "3.11", "3.10", "3.9", "3.8", "3.7"]
+ python-version: ["3.13", "3.12", "3.11", "3.10", "3.9"]
neo4j-version: ["community", "enterprise", "5.5-enterprise", "4.4-enterprise", "4.4-community"]
steps:
diff --git a/Changelog b/Changelog
index b65b90a7..079cab72 100644
--- a/Changelog
+++ b/Changelog
@@ -1,3 +1,15 @@
+Version 5.4.0 2024-11
+* Traversal option for filtering and ordering
+* Insert raw Cypher for ordering
+* Possibility to traverse relations, only returning the last element of the path
+* Resolve the results of complex queries as a nested subgraph
+* Possibility to transform variables, with aggregations methods : Collect() and Last()
+* Intermediate transform, for example to order variables before collecting
+* Subqueries (Cypher CALL{} clause)
+* Allow JSONProperty to actually use non-ascii elements. Thanks to @danikirish
+* Bumped neo4j (driver) to 5.26.0
+* Special huge thanks to @tonioo for this release
+
Version 5.3.3 2024-09
* Fixes vector index doc and test
diff --git a/README.md b/README.md
index 633cb2f1..f72d5a41 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ GitHub repo found at .
**For neomodel releases 5.x :**
-- Python 3.7+
+- Python 3.8+
- Neo4j 5.x, 4.4 (LTS)
**For neomodel releases 4.x :**
@@ -37,6 +37,14 @@ GitHub repo found at .
Available on
[readthedocs](http://neomodel.readthedocs.org).
+# New in 5.4.0
+
+This version adds many new features, expanding neomodel's querying capabilities. Those features were kindly contributed back by the [OpenStudyBuilder team](https://openstudybuilder.com/). A VERY special thanks to @tonioo for the integration work.
+
+There are too many new capabilities here, so I advise you to start by looking at the full summary example in the [Getting Started guide](https://neomodel.readthedocs.io/en/latest/getting_started.html#full-example). It will then point you to the various relevant sections.
+
+We also validated support for [Python 3.13](https://docs.python.org/3/whatsnew/3.13.html).
+
# New in 5.3.0
neomodel now supports asynchronous programming, thanks to the [Neo4j driver async API](https://neo4j.com/docs/api/python-driver/current/async_api.html). The [documentation](http://neomodel.readthedocs.org) has been updated accordingly, with an updated getting started section, and some specific documentation for the async API.
@@ -96,7 +104,7 @@ Ensure `dbms.security.auth_enabled=true` in your database configuration
file. Setup a virtual environment, install neomodel for development and
run the test suite: :
- $ pip install -e '.[dev,pandas,numpy]'
+ $ pip install -r requirements-dev.txt
$ pytest
The tests in \"test_connection.py\" will fail locally if you don\'t
diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst
new file mode 100644
index 00000000..e602d479
--- /dev/null
+++ b/doc/source/advanced_query_operations.rst
@@ -0,0 +1,111 @@
+.. _Advanced query operations:
+
+=========================
+Advanced query operations
+=========================
+
+neomodel provides ways to enhance your queries beyond filtering and traversals.
+
+Annotate - Aliasing
+-------------------
+
+The `annotate` method allows you to add transformations to your elements. To learn more about the available transformations, keep reading this section.
+
+Aggregations
+------------
+
+neomodel implements some of the aggregation methods available in Cypher:
+
+- Collect (with distinct option)
+- Last
+
+These are usable in this way::
+
+ from neomodel.sync_.match import Collect, Last
+
+ # distinct is optional, and defaults to False. When true, objects are deduplicated
+ Supplier.nodes.traverse_relations(available_species="coffees__species")
+ .annotate(Collect("available_species", distinct=True))
+ .all()
+
+ # Last is used to get the last element of a list
+ Supplier.nodes.traverse_relations(available_species="coffees__species")
+ .annotate(Last(Collect("last_species")))
+ .all()
+
+Note how `annotate` is used to add the aggregation method to the query.
+
+.. note::
+ Using the Last() method right after a Collect() without having set an ordering will return the last element in the list as it was returned by the database.
+
+ This is probably not what you want, which means you must provide an explicit ordering. To do so, you cannot use neomodel's `order_by` method, but need an intermediate transformation step (see below).
+
+ This is because the order_by method adds ordering as the very last step of the Cypher query; whereas in the present example, you want to first order Species, then get the last one, and then finally return your results. In other words, you need an intermediate WITH Cypher clause.
+
+Intermediate transformations
+----------------------------
+
+The `intermediate_transform` method basically allows you to add a WITH clause to your query. This is useful when you need to perform some operations on your results before returning them.
+
+As discussed in the note above, this is for example useful when you need to order your results before applying an aggregation method, like so::
+
+ from neomodel.sync_.match import Collect, Last
+
+ # This will return all Coffee nodes, with their most expensive supplier
+ Coffee.nodes.traverse_relations(suppliers="suppliers")
+ .intermediate_transform(
+ {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"]
+ )
+ .annotate(supps=Last(Collect("suppliers")))
+
+Subqueries
+----------
+
+The `subquery` method allows you to perform a `Cypher subquery <https://neo4j.com/docs/cypher-manual/current/subqueries/call-subquery/>`_ inside your query. This allows you to perform operations in isolation to the rest of your query::
+
+ from neomodel.sync_.match import Collect, Last
+
+ # This will create a CALL{} subquery
+ # And return a variable named supps usable in the rest of your query
+ Coffee.nodes.filter(name="Espresso")
+ .subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers")
+ .intermediate_transform(
+ {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"]
+ )
+ .annotate(supps=Last(Collect("suppliers"))),
+ ["supps"],
+ )
+
+.. note::
+ Notice the subquery starts with Coffee.nodes ; neomodel will use this to know it needs to inject the source "coffee" variable generated by the outer query into the subquery. This means only Espresso coffee nodes will be considered in the subquery.
+
+ We know this is confusing to read, but have not found a better way to do this yet. If you have any suggestions, please let us know.
+
+Helpers
+-------
+
+Reading the sections above, you may have noticed that we used explicit aliasing in the examples, as in::
+
+ traverse_relations(suppliers="suppliers")
+
+This allows you to reference the generated Cypher variables in your transformation steps, for example::
+
+ traverse_relations(suppliers="suppliers").annotate(Collect("suppliers"))
+
+In some cases though, it is not possible to set explicit aliases, for example when using `fetch_relations`. In these cases, neomodel provides `resolver` methods, so you do not have to guess the name of the variable in the generated Cypher. Those are `NodeNameResolver` and `RelationshipNameResolver`. For example::
+
+ from neomodel.sync_.match import Collect, NodeNameResolver, RelationNameResolver
+
+ Supplier.nodes.fetch_relations("coffees__species")
+ .annotate(
+ all_species=Collect(NodeNameResolver("coffees__species"), distinct=True),
+ all_species_rels=Collect(
+ RelationNameResolver("coffees__species"), distinct=True
+ ),
+ )
+ .all()
+
+.. note::
+
+ When using the resolvers in combination with a traversal as in the example above, it will resolve the variable name of the last element in the traversal - the Species node for NodeNameResolver, and Coffee--Species relationship for RelationNameResolver.
\ No newline at end of file
diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst
index e5d5d5d8..511318c4 100644
--- a/doc/source/configuration.rst
+++ b/doc/source/configuration.rst
@@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti
config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default
config.RESOLVER = None # default
config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default
- config.USER_AGENT = neomodel/v5.3.3 # default
+ config.USER_AGENT = neomodel/v5.4.0 # default
Setting the database name, if different from the default one::
diff --git a/doc/source/cypher.rst b/doc/source/cypher.rst
index f8c7ccaf..8ce2a42e 100644
--- a/doc/source/cypher.rst
+++ b/doc/source/cypher.rst
@@ -24,6 +24,18 @@ Outside of a `StructuredNode`::
The ``resolve_objects`` parameter automatically inflates the returned nodes to their defined classes (this is turned **off** by default). See :ref:`automatic_class_resolution` for details and possible pitfalls.
+You can also retrieve a whole path of already instantiated objects corresponding to
+the nodes and relationship classes with a single query::
+
+ q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1",
+ resolve_objects = True)
+
+Notice here that ``resolve_objects`` is set to ``True``. This results in ``q`` being a
+list of ``result, result_name`` and ``q[0][0][0]`` being a ``NeomodelPath`` object.
+
+``NeomodelPath`` ``nodes, relationships`` attributes contain already instantiated objects of the
+nodes and relationships in the query, *in order of appearance*.
+
Integrations
============
diff --git a/doc/source/filtering_ordering.rst b/doc/source/filtering_ordering.rst
new file mode 100644
index 00000000..41a37581
--- /dev/null
+++ b/doc/source/filtering_ordering.rst
@@ -0,0 +1,199 @@
+.. _Filtering and ordering:
+
+======================
+Filtering and ordering
+======================
+
+For the examples in this section, we will be using the following model::
+
+ class SupplierRel(StructuredRel):
+ since = DateTimeProperty(default=datetime.now)
+
+
+ class Supplier(StructuredNode):
+ name = StringProperty()
+ delivery_cost = IntegerProperty()
+
+
+ class Coffee(StructuredNode):
+ name = StringProperty(unique_index=True)
+ price = IntegerProperty()
+ suppliers = RelationshipFrom(Supplier, 'SUPPLIES', model=SupplierRel)
+
+Filtering
+=========
+
+neomodel allows filtering on nodes' and relationships' properties. Filters can be combined using Django's Q syntax. It also allows multi-hop relationship traversals to filter on "remote" elements.
+
+Filter methods
+--------------
+
+The ``.nodes`` property of a class returns all nodes of that type from the database.
+
+This set (called `NodeSet`) can be iterated over and filtered on, using the `.filter` method::
+
+ # nodes with label Coffee whose price is greater than 2
+ high_end_coffees = Coffee.nodes.filter(price__gt=2)
+
+ try:
+ java = Coffee.nodes.get(name='Java')
+ except DoesNotExist:
+ # .filter will not throw an exception if no results are found
+ # but .get will
+ print("Couldn't find coffee 'Java'")
+
+The filter method borrows the same Django filter format with double underscore prefixed operators:
+
+- lt - less than
+- gt - greater than
+- lte - less than or equal to
+- gte - greater than or equal to
+- ne - not equal
+- in - item in list
+- isnull - `True` IS NULL, `False` IS NOT NULL
+- exact - string equals
+- iexact - string equals, case insensitive
+- contains - contains string value
+- icontains - contains string value, case insensitive
+- startswith - starts with string value
+- istartswith - starts with string value, case insensitive
+- endswith - ends with string value
+- iendswith - ends with string value, case insensitive
+- regex - matches a regex expression
+- iregex - matches a regex expression, case insensitive
+
+These operators work with both `.get` and `.filter` methods.
+
+Combining filters
+-----------------
+
+The filter method allows you to combine multiple filters::
+
+ cheap_arabicas = Coffee.nodes.filter(price__lt=5, name__icontains='arabica')
+
+These filters are combined using the logical AND operator. To execute more complex logic (for example, queries with OR statements), `Q objects <https://docs.djangoproject.com/en/stable/topics/db/queries/#complex-lookups-with-q-objects>`_ can be used. This is borrowed from Django.
+
+``Q`` objects can be combined using the ``&`` and ``|`` operators. Statements of arbitrary complexity can be composed by combining ``Q`` objects
+with the ``&`` and ``|`` operators and use parenthetical grouping. Also, ``Q``
+objects can be negated using the ``~`` operator, allowing for combined lookups
+that combine both a normal query and a negated (``NOT``) query::
+
+ Q(name__icontains='arabica') | ~Q(name__endswith='blend')
+
+Chaining ``Q`` objects will join them as an AND clause::
+
+ not_middle_priced_arabicas = Coffee.nodes.filter(
+ Q(name__icontains='arabica'),
+ Q(price__lt=5) | Q(price__gt=10)
+ )
+
+Traversals and filtering
+------------------------
+
+Sometimes you need to filter nodes based on other nodes they are connected to. This can be done by including a traversal in the `filter` method. ::
+
+ # Find all suppliers of coffee 'Java' who have been supplying since 2007
+ # But whose prices are greater than 5
+ since_date = datetime(2007, 1, 1)
+ java_old_timers = Coffee.nodes.filter(
+ name='Java',
+ suppliers__delivery_cost__gt=5,
+ **{"suppliers|since__lt": since_date}
+ )
+
+In the example above, note the following syntax elements:
+
+- The name of relationships as defined in the `StructuredNode` class is used to traverse relationships. `suppliers` in this example.
+- Double underscore `__` is used to target a property of a node. `delivery_cost` in this example.
+- A pipe `|` is used to separate the relationship traversal from the property filter. The filter also has to be included in a `**kwargs` dictionary, because the pipe character would break the syntax. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship.
+- The filter operators like lt, gt, etc. can be used on the filtered property.
+
+Traversals can be of any length, with each relationship separated by a double underscore `__`, for example::
+
+ # country is here a relationship between Supplier and Country
+ Coffee.nodes.filter(suppliers__country__name='Brazil')
+
+Enforcing relationship/path existence
+-------------------------------------
+
+The `has` method checks for existence of (one or more) relationships, in this case it returns a set of `Coffee` nodes which have a supplier::
+
+ Coffee.nodes.has(suppliers=True)
+
+This can be negated by setting `suppliers=False`, to find `Coffee` nodes without `suppliers`.
+
+You can also filter on the existence of more complex traversals by using the `traverse_relations` method. See :ref:`Path traversal`.
+
+Ordering
+========
+
+neomodel allows ordering by nodes' and relationships' properties. Order can be ascending or descending. It also allows multi-hop relationship traversals to order on "remote" elements. Finally, you can inject raw Cypher clauses to have full control over ordering when necessary.
+
+order_by
+--------
+
+Ordering results by a particular property is done via the `order_by` method::
+
+ # Ascending sort
+ for coffee in Coffee.nodes.order_by('price'):
+ print(coffee, coffee.price)
+
+ # Descending sort
+ for supplier in Supplier.nodes.order_by('-delivery_cost'):
+ print(supplier, supplier.delivery_cost)
+
+
+Removing the ordering from a previously defined query, is done by passing `None` to `order_by`::
+
+ # Sort in descending order
+ suppliers = Supplier.nodes.order_by('-delivery_cost')
+
+ # Don't order; yield nodes in the order neo4j returns them
+ suppliers = suppliers.order_by(None)
+
+For random ordering simply pass '?' to the order_by method::
+
+ Coffee.nodes.order_by('?')
+
+Traversals and ordering
+-----------------------
+
+Sometimes you need to order results based on properties situated on different nodes or relationships. This can be done by including a traversal in the `order_by` method. ::
+
+ # Find the most expensive coffee to deliver
+ # Then order by the date the supplier started supplying
+ Coffee.nodes.order_by(
+ '-suppliers__delivery_cost',
+ 'suppliers|since',
+ )
+
+In the example above, note the following syntax elements:
+
+- The name of relationships as defined in the `StructuredNode` class is used to traverse relationships. `suppliers` in this example.
+- Double underscore `__` is used to target a property of a node. `delivery_cost` in this example.
+- A pipe `|` is used to separate the relationship traversal from the property filter. This is a special syntax to indicate that the filter is on the relationship itself, not on the node at the end of the relationship.
+
+Traversals can be of any length, with each relationship separated by a double underscore `__`, for example::
+
+ # country is here a relationship between Supplier and Country
+ Coffee.nodes.order_by('suppliers__country__latitude')
+
+RawCypher
+---------
+
+When you need more advanced ordering capabilities, for example to apply order to a transformed property, you can use the `RawCypher` method, like so::
+
+ from neomodel.sync_.match import RawCypher
+
+ class SoftwareDependency(StructuredNode):
+ name = StringProperty()
+ version = StringProperty()
+
+ SoftwareDependency(name="Package2", version="1.4.0").save()
+ SoftwareDependency(name="Package3", version="2.5.5").save()
+
+ latest_dep = SoftwareDependency.nodes.order_by(
+ RawCypher("toInteger(split($n.version, '.')[0]) DESC"),
+ )
+
+In the example above, note the `$n` placeholder in the `RawCypher` clause. This is a placeholder for the node being ordered (`SoftwareDependency` in this case).
diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst
index 6e8a5aa0..dffa94ec 100644
--- a/doc/source/getting_started.rst
+++ b/doc/source/getting_started.rst
@@ -193,6 +193,28 @@ simply returning the node IDs rather than every attribute associated with that N
# Return set of nodes
people = Person.nodes.filter(age__gt=3)
+Iteration, slicing and more
+---------------------------
+
+Iteration, slicing and counting is also supported::
+
+ # Iterable
+ for coffee in Coffee.nodes:
+ print(coffee.name)
+
+ # Sliceable using python slice syntax
+ coffee = Coffee.nodes.filter(price__gt=2)[2:]
+
+The slice syntax returns a NodeSet object which can in turn be chained.
+
+Length and boolean methods do not return NodeSet objects and cannot be chained further::
+
+ # Count with __len__
+ print(len(Coffee.nodes.filter(price__gt=2)))
+
+ if Coffee.nodes:
+ print "We have coffee nodes!"
+
Relationships
=============
@@ -236,38 +258,6 @@ Working with relationships::
Retrieving additional relations
===============================
-To avoid queries multiplication, you have the possibility to retrieve
-additional relations with a single call::
-
- # The following call will generate one MATCH with traversal per
- # item in .fetch_relations() call
- results = Person.nodes.fetch_relations('country').all()
- for result in results:
- print(result[0]) # Person
- print(result[1]) # associated Country
-
-You can traverse more than one hop in your relations using the
-following syntax::
-
- # Go from person to City then Country
- Person.nodes.fetch_relations('city__country').all()
-
-You can also force the use of an ``OPTIONAL MATCH`` statement using
-the following syntax::
-
- from neomodel.match import Optional
-
- results = Person.nodes.fetch_relations(Optional('country')).all()
-
-.. note::
-
- Any relationship that you intend to traverse using this method **MUST have a model defined**, even if only the default StructuredRel, like::
-
- class Person(StructuredNode):
- country = RelationshipTo(Country, 'IS_FROM', model=StructuredRel)
-
- Otherwise, neomodel will not be able to determine which relationship model to resolve into, and will fail.
-
.. note::
You can fetch one or more relations within the same call
@@ -339,3 +329,93 @@ Most _dunder_ methods for nodes and relationships had to be overriden to support
# Sync equivalent - __getitem__
assert len(list(Coffee.nodes[1:])) == 2
+
+Full example
+============
+
+The example below will show you how you can mix and match query operations, as described in :ref:`Filtering and ordering`, :ref:`Path traversal`, or :ref:`Advanced query operations`::
+
+ # These are the class definitions used for the query below
+ class HasCourseRel(AsyncStructuredRel):
+ level = StringProperty()
+ start_date = DateTimeProperty()
+ end_date = DateTimeProperty()
+
+
+ class Course(AsyncStructuredNode):
+ name = StringProperty()
+
+
+ class Building(AsyncStructuredNode):
+ name = StringProperty()
+
+
+ class Student(AsyncStructuredNode):
+ name = StringProperty()
+
+ parents = AsyncRelationshipTo("Student", "HAS_PARENT", model=AsyncStructuredRel)
+ children = AsyncRelationshipFrom("Student", "HAS_PARENT", model=AsyncStructuredRel)
+ lives_in = AsyncRelationshipTo(Building, "LIVES_IN", model=AsyncStructuredRel)
+ courses = AsyncRelationshipTo(Course, "HAS_COURSE", model=HasCourseRel)
+ preferred_course = AsyncRelationshipTo(
+ Course,
+ "HAS_PREFERRED_COURSE",
+ model=AsyncStructuredRel,
+ cardinality=AsyncZeroOrOne,
+ )
+
+ # This is the query
+ full_nodeset = (
+ await Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower") # Combine filters
+ .order_by("name")
+ .fetch_relations(
+ "parents",
+ Optional("children__preferred_course"),
+ ) # Combine fetch_relations
+ .subquery(
+ Student.nodes.fetch_relations("courses") # Root variable student will be auto-injected here
+ .intermediate_transform(
+ {"rel": RelationNameResolver("courses")},
+ ordering=[
+ RawCypher("toInteger(split(rel.level, '.')[0])"),
+ RawCypher("toInteger(split(rel.level, '.')[1])"),
+ "rel.end_date",
+ "rel.start_date",
+ ], # Intermediate ordering
+ )
+ .annotate(
+ latest_course=Last(Collect("rel")),
+ ),
+ ["latest_course"],
+ )
+ )
+
+ # Using async, we need to do 2 await
+ # One is for subquery, the other is for resolve_subgraph
+ # It only runs a single Cypher query though
+ subgraph = await full_nodeset.annotate(
+ children=Collect(NodeNameResolver("children"), distinct=True),
+ children_preferred_course=Collect(
+ NodeNameResolver("children__preferred_course"), distinct=True
+ ),
+ ).resolve_subgraph()
+
+ # The generated Cypher query looks like this
+ query = """
+ MATCH (student:Student)-[r1:`HAS_PARENT`]->(student_parents:Student)
+ MATCH (student)-[r4:`LIVES_IN`]->(building_lives_in:Building)
+ OPTIONAL MATCH (student)<-[r2:`HAS_PARENT`]-(student_children:Student)-[r3:`HAS_PREFERRED_COURSE`]->(course_children__preferred_course:Course)
+ WITH *
+ # building_lives_in_name_1 = "Eiffel Tower"
+ # student_name_1 = "(?i)m.*"
+ WHERE building_lives_in.name = $building_lives_in_name_1 AND student.name =~ $student_name_1
+ CALL {
+ WITH student
+ MATCH (student)-[r1:`HAS_COURSE`]->(course_courses:Course)
+ WITH r1 AS rel
+ ORDER BY toInteger(split(rel.level, '.')[0]),toInteger(split(rel.level, '.')[1]),rel.end_date,rel.start_date
+ RETURN last(collect(rel)) AS latest_course
+ }
+ RETURN latest_course, student, student_parents, r1, student_children, r2, course_children__preferred_course, r3, building_lives_in, r4, collect(DISTINCT student_children) AS children, collect(DISTINCT course_children__preferred_course) AS children_preferred_course
+ ORDER BY student.name
+ """
\ No newline at end of file
diff --git a/doc/source/index.rst b/doc/source/index.rst
index 91a728c0..068e2d93 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -74,7 +74,9 @@ Contents
properties
spatial_properties
schema_management
- queries
+ filtering_ordering
+ traversal
+ advanced_query_operations
cypher
transactions
hooks
diff --git a/doc/source/queries.rst b/doc/source/queries.rst
deleted file mode 100644
index 4c77a791..00000000
--- a/doc/source/queries.rst
+++ /dev/null
@@ -1,258 +0,0 @@
-================
-Advanced queries
-================
-
-Neomodel contains an API for querying sets of nodes without having to write cypher::
-
- class SupplierRel(StructuredRel):
- since = DateTimeProperty(default=datetime.now)
-
-
- class Supplier(StructuredNode):
- name = StringProperty()
- delivery_cost = IntegerProperty()
- coffees = RelationshipTo('Coffee', 'SUPPLIES')
-
-
- class Coffee(StructuredNode):
- name = StringProperty(unique_index=True)
- price = IntegerProperty()
- suppliers = RelationshipFrom(Supplier, 'SUPPLIES', model=SupplierRel)
-
-Node sets and filtering
-=======================
-
-The ``.nodes`` property of a class returns all nodes of that type from the database.
-
-This set (or `NodeSet`) can be iterated over and filtered on. Under the hood it uses labels introduced in Neo4J 2::
-
- # nodes with label Coffee whose price is greater than 2
- Coffee.nodes.filter(price__gt=2)
-
- try:
- java = Coffee.nodes.get(name='Java')
- except Coffee.DoesNotExist:
- print "Couldn't find coffee 'Java'"
-
-The filter method borrows the same Django filter format with double underscore prefixed operators:
-
-- lt - less than
-- gt - greater than
-- lte - less than or equal to
-- gte - greater than or equal to
-- ne - not equal
-- in - item in list
-- isnull - `True` IS NULL, `False` IS NOT NULL
-- exact - string equals
-- iexact - string equals, case insensitive
-- contains - contains string value
-- icontains - contains string value, case insensitive
-- startswith - starts with string value
-- istartswith - starts with string value, case insensitive
-- endswith - ends with string value
-- iendswith - ends with string value, case insensitive
-- regex - matches a regex expression
-- iregex - matches a regex expression, case insensitive
-
-Complex lookups with ``Q`` objects
-==================================
-
-Keyword argument queries -- in `filter`,
-etc. -- are "AND"ed together. To execute more complex queries (for
-example, queries with ``OR`` statements), `Q objects ` can
-be used.
-
-A `Q object` (``neomodel.Q``) is an object
-used to encapsulate a collection of keyword arguments. These keyword arguments
-are specified as in "Field lookups" above.
-
-For example, this ``Q`` object encapsulates a single ``LIKE`` query::
-
- from neomodel import Q
- Q(name__startswith='Py')
-
-``Q`` objects can be combined using the ``&`` and ``|`` operators. When an
-operator is used on two ``Q`` objects, it yields a new ``Q`` object.
-
-For example, this statement yields a single ``Q`` object that represents the
-"OR" of two ``"name__startswith"`` queries::
-
- Q(name__startswith='Py') | Q(name__startswith='Jav')
-
-This is equivalent to the following SQL ``WHERE`` clause::
-
- WHERE name STARTS WITH 'Py' OR name STARTS WITH 'Jav'
-
-Statements of arbitrary complexity can be composed by combining ``Q`` objects
-with the ``&`` and ``|`` operators and use parenthetical grouping. Also, ``Q``
-objects can be negated using the ``~`` operator, allowing for combined lookups
-that combine both a normal query and a negated (``NOT``) query::
-
- Q(name__startswith='Py') | ~Q(year=2005)
-
-Each lookup function that takes keyword-arguments
-(e.g. `filter`, `exclude`, `get`) can also be passed one or more
-``Q`` objects as positional (not-named) arguments. If multiple
-``Q`` object arguments are provided to a lookup function, the arguments will be "AND"ed
-together. For example::
-
- Lang.nodes.filter(
- Q(name__startswith='Py'),
- Q(year=2005) | Q(year=2006)
- )
-
-This roughly translates to the following Cypher query::
-
- MATCH (lang:Lang) WHERE name STARTS WITH 'Py'
- AND (year = 2005 OR year = 2006)
- return lang;
-
-Lookup functions can mix the use of ``Q`` objects and keyword arguments. All
-arguments provided to a lookup function (be they keyword arguments or ``Q``
-objects) are "AND"ed together. However, if a ``Q`` object is provided, it must
-precede the definition of any keyword arguments. For example::
-
- Lang.nodes.get(
- Q(year=2005) | Q(year=2006),
- name__startswith='Py',
- )
-
-This would be a valid query, equivalent to the previous example;
-
-Has a relationship
-==================
-
-The `has` method checks for existence of (one or more) relationships, in this case it returns a set of `Coffee` nodes which have a supplier::
-
- Coffee.nodes.has(suppliers=True)
-
-This can be negated by setting `suppliers=False`, to find `Coffee` nodes without `suppliers`.
-
-Iteration, slicing and more
-===========================
-
-Iteration, slicing and counting is also supported::
-
- # Iterable
- for coffee in Coffee.nodes:
- print coffee.name
-
- # Sliceable using python slice syntax
- coffee = Coffee.nodes.filter(price__gt=2)[2:]
-
-The slice syntax returns a NodeSet object which can in turn be chained.
-
-Length and boolean methods dont return NodeSet objects and cannot be chained further::
-
- # Count with __len__
- print len(Coffee.nodes.filter(price__gt=2))
-
- if Coffee.nodes:
- print "We have coffee nodes!"
-
-Filtering by relationship properties
-====================================
-
-Filtering on relationship properties is also possible using the `match` method. Note that again these relationships must have a definition.::
-
- coffee_brand = Coffee.nodes.get(name="BestCoffeeEver")
-
- for supplier in coffee_brand.suppliers.match(since_lt=january):
- print(supplier.name)
-
-Ordering by property
-====================
-
-Ordering results by a particular property is done via th `order_by` method::
-
- # Ascending sort
- for coffee in Coffee.nodes.order_by('price'):
- print(coffee, coffee.price)
-
- # Descending sort
- for supplier in Supplier.nodes.order_by('-delivery_cost'):
- print(supplier, supplier.delivery_cost)
-
-
-Removing the ordering from a previously defined query, is done by passing `None` to `order_by`::
-
- # Sort in descending order
- suppliers = Supplier.nodes.order_by('-delivery_cost')
-
- # Don't order; yield nodes in the order neo4j returns them
- suppliers = suppliers.order_by(None)
-
-For random ordering simply pass '?' to the order_by method::
-
- Coffee.nodes.order_by('?')
-
-Retrieving paths
-================
-
-You can retrieve a whole path of already instantiated objects corresponding to
-the nodes and relationship classes with a single query.
-
-Suppose the following schema:
-
-::
-
- class PersonLivesInCity(StructuredRel):
- some_num = IntegerProperty(index=True,
- default=12)
-
- class CountryOfOrigin(StructuredNode):
- code = StringProperty(unique_index=True,
- required=True)
-
- class CityOfResidence(StructuredNode):
- name = StringProperty(required=True)
- country = RelationshipTo(CountryOfOrigin,
- 'FROM_COUNTRY')
-
- class PersonOfInterest(StructuredNode):
- uid = UniqueIdProperty()
- name = StringProperty(unique_index=True)
- age = IntegerProperty(index=True,
- default=0)
-
- country = RelationshipTo(CountryOfOrigin,
- 'IS_FROM')
- city = RelationshipTo(CityOfResidence,
- 'LIVES_IN',
- model=PersonLivesInCity)
-
-Then, paths can be retrieved with:
-
-::
-
- q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1",
- resolve_objects = True)
-
-Notice here that ``resolve_objects`` is set to ``True``. This results in ``q`` being a
-list of ``result, result_name`` and ``q[0][0][0]`` being a ``NeomodelPath`` object.
-
-``NeomodelPath`` ``nodes, relationships`` attributes contain already instantiated objects of the
-nodes and relationships in the query, *in order of appearance*.
-
-It would be particularly useful to note here that each object is read exactly once from
-the database. Therefore, nodes will be instantiated to their neomodel node objects and
-relationships to their relationship models *if such a model exists*. In other words,
-relationships with data (such as ``PersonLivesInCity`` above) will be instantiated to their
-respective objects or ``StrucuredRel`` otherwise. Relationships do not "reload" their
-end-points (unless this is required).
-
-Async neomodel - Caveats
-========================
-
-Python does not support async dunder methods. This means that we had to implement some overrides for those.
-See the example below::
-
- # This will not work as it uses the synchronous __bool__ method
- assert await Customer.nodes.filter(prop="value")
-
- # Do this instead
- assert await Customer.nodes.filter(prop="value").check_bool()
- assert await Customer.nodes.filter(prop="value").check_nonzero()
-
- # Note : no changes are needed for sync so this still works :
- assert Customer.nodes.filter(prop="value")
diff --git a/doc/source/traversal.rst b/doc/source/traversal.rst
new file mode 100644
index 00000000..c6347ad0
--- /dev/null
+++ b/doc/source/traversal.rst
@@ -0,0 +1,96 @@
+.. _Path traversal:
+
+==============
+Path traversal
+==============
+
+Neo4j is about traversing the graph, which means leveraging nodes and relations between them. This section will show you how to traverse the graph using neomodel.
+
+We will cover two methods: `traverse_relations` and `fetch_relations`. These two methods are *mutually exclusive*, so you cannot chain them.
+
+For the examples in this section, we will be using the following model::
+
+ class Country(StructuredNode):
+ country_code = StringProperty(unique_index=True)
+ name = StringProperty()
+
+ class Supplier(StructuredNode):
+ name = StringProperty()
+ delivery_cost = IntegerProperty()
+ country = RelationshipTo(Country, 'ESTABLISHED_IN')
+
+ class Coffee(StructuredNode):
+ name = StringProperty(unique_index=True)
+ price = IntegerProperty()
+ suppliers = RelationshipFrom(Supplier, 'SUPPLIES')
+
+Traverse relations
+------------------
+
+The `traverse_relations` method allows you to filter on the existence of more complex traversals. For example, to find all `Coffee` nodes that have a supplier, and retrieve the country of that supplier, you can do::
+
+ Coffee.nodes.traverse_relations(country='suppliers__country').all()
+
+This will generate a Cypher MATCH clause that enforces the existence of at least one path like `Coffee<--Supplier-->Country`.
+
+The `Country` nodes matched will be made available for the rest of the query, with the variable name `country`. Note that this aliasing is optional. See :ref:`Advanced query operations` for examples of how to use this aliasing.
+
+.. note::
+
+ The `traverse_relations` method can be used to traverse multiple relationships, like::
+
+ Coffee.nodes.traverse_relations('suppliers__country', 'pub__city').all()
+
+ This will generate a Cypher MATCH clause that enforces the existence of at least one path like `Coffee<--Supplier-->Country` and `Coffee<--Pub-->City`.
+
+Fetch relations
+---------------
+
+The syntax for `fetch_relations` is similar to `traverse_relations`, except that the generated Cypher will return all traversed objects (nodes and relations)::
+
+ Coffee.nodes.fetch_relations(country='suppliers__country').all()
+
+.. note::
+
+ Any relationship that you intend to traverse using this method **MUST have a model defined**, even if only the default StructuredRel, like::
+
+ class Person(StructuredNode):
+ country = RelationshipTo(Country, 'IS_FROM', model=StructuredRel)
+
+ Otherwise, neomodel will not be able to determine which relationship model to resolve into, and will fail.
+
+Optional match
+--------------
+
+With both `traverse_relations` and `fetch_relations`, you can force the use of an ``OPTIONAL MATCH`` statement using the following syntax::
+
+ from neomodel.match import Optional
+
+    # Return the Coffee nodes, and if they have suppliers, return the suppliers as well
+ results = Coffee.nodes.fetch_relations(Optional('suppliers')).all()
+
+.. note::
+
+ You can fetch one or more relations within the same call
+ to `.fetch_relations()` and you can mix optional and non-optional
+ relations, like::
+
+ Person.nodes.fetch_relations('city__country', Optional('country')).all()
+
+Resolve results
+---------------
+
+By default, `fetch_relations` will return a list of tuples. If your path looks like ``(startNode:Coffee)<-[r1]-(middleNode:Supplier)-[r2]->(endNode:Country)``,
+then you will get a list of results, where each result is a list of ``(startNode, r1, middleNode, r2, endNode)``.
+These will be resolved by neomodel, so ``startNode`` will be an instance of the ``Coffee`` class defined in neomodel, for example.
+
+Using the `resolve_subgraph` method, you can get instead a list of "subgraphs", where each returned `StructuredNode` element will contain its relations and neighbour nodes. For example::
+
+ results = Coffee.nodes.fetch_relations('suppliers__country').resolve_subgraph().all()
+
+In this example, `results[0]` will be a `Coffee` object, with a `_relations` attribute. This will in turn have a `suppliers` and a `suppliers_relationship` attribute, which will contain the `Supplier` object and the relation object respectively. Recursively, the `Supplier` object will have a `country` attribute, which will contain the `Country` object.
+
+.. note::
+
+ The `resolve_subgraph` method is only available for `fetch_relations` queries. This is because `traverse_relations` queries do not return any relations, and thus there is no need to resolve them.
+
diff --git a/docker-scripts/docker-neo4j.sh b/docker-scripts/docker-neo4j.sh
index 99aabfff..6b146c95 100644
--- a/docker-scripts/docker-neo4j.sh
+++ b/docker-scripts/docker-neo4j.sh
@@ -5,4 +5,5 @@ docker run \
--env NEO4J_AUTH=neo4j/foobarbaz \
--env NEO4J_ACCEPT_LICENSE_AGREEMENT=yes \
--env NEO4JLABS_PLUGINS='["apoc"]' \
- neo4j:$1
\ No newline at end of file
+ --rm \
+ neo4j:$1
diff --git a/neomodel/_version.py b/neomodel/_version.py
index d2f4a6f4..fc30498f 100644
--- a/neomodel/_version.py
+++ b/neomodel/_version.py
@@ -1 +1 @@
-__version__ = "5.3.3"
+__version__ = "5.4.0"
diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py
index 7f0435fe..6a2f9933 100644
--- a/neomodel/async_/match.py
+++ b/neomodel/async_/match.py
@@ -1,15 +1,21 @@
import inspect
import re
-from collections import defaultdict
+import string
from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, List
+from typing import Optional as TOptional
+from typing import Tuple, Union
+from neomodel.async_ import relationship_manager
from neomodel.async_.core import AsyncStructuredNode, adb
+from neomodel.async_.relationship import AsyncStructuredRel
from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
-from neomodel.properties import AliasProperty, ArrayProperty
+from neomodel.properties import AliasProperty, ArrayProperty, Property
from neomodel.util import INCOMING, OUTGOING
+CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
+
def _rel_helper(
lhs,
@@ -194,6 +200,8 @@ def _rel_merge_helper(
# add all regex operators
OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE)
+path_split_regex = re.compile(r"__(?!_)|\|")
+
def install_traversals(cls, node_set):
"""
@@ -213,136 +221,122 @@ def install_traversals(cls, node_set):
setattr(node_set, key, traversal)
-def process_filter_args(cls, kwargs):
- """
- loop through properties in filter parameters check they match class definition
- deflate them and convert into something easy to generate cypher from
- """
-
- output = {}
-
- for key, value in kwargs.items():
- if "__" in key:
- prop, operator = key.rsplit("__")
- operator = OPERATOR_TABLE[operator]
- else:
- prop = key
- operator = "="
-
- if prop not in cls.defined_properties(rels=False):
+def _handle_special_operators(
+ property_obj: Property, key: str, value: str, operator: str, prop: str
+) -> Tuple[str, str, str]:
+ if operator == _SPECIAL_OPERATOR_IN:
+ if not isinstance(value, (list, tuple)):
raise ValueError(
- f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
+ f"Value must be a tuple or list for IN operation {key}={value}"
)
-
- property_obj = getattr(cls, prop)
- if isinstance(property_obj, AliasProperty):
- prop = property_obj.aliased_to()
- deflated_value = getattr(cls, prop).deflate(value)
+ if isinstance(property_obj, ArrayProperty):
+ deflated_value = property_obj.deflate(value)
+ operator = _SPECIAL_OPERATOR_ARRAY_IN
else:
- operator, deflated_value = transform_operator_to_filter(
- operator=operator,
- filter_key=key,
- filter_value=value,
- property_obj=property_obj,
- )
-
- # map property to correct property name in the database
- db_property = cls.defined_properties(rels=False)[prop].get_db_property_name(
- prop
- )
-
- output[db_property] = (operator, deflated_value)
+ deflated_value = [property_obj.deflate(v) for v in value]
+ elif operator == _SPECIAL_OPERATOR_ISNULL:
+ if not isinstance(value, bool):
+ raise ValueError(f"Value must be a bool for isnull operation on {key}")
+ operator = "IS NULL" if value else "IS NOT NULL"
+ deflated_value = None
+ elif operator in _REGEX_OPERATOR_TABLE.values():
+ deflated_value = property_obj.deflate(value)
+ if not isinstance(deflated_value, str):
+ raise ValueError(f"Must be a string value for {key}")
+ if operator in _STRING_REGEX_OPERATOR_TABLE.values():
+ deflated_value = re.escape(deflated_value)
+ deflated_value = operator.format(deflated_value)
+ operator = _SPECIAL_OPERATOR_REGEX
+ else:
+ deflated_value = property_obj.deflate(value)
- return output
+ return deflated_value, operator, prop
-def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj):
- """
- Transform in operator to a cypher filter
- Args:
- operator (str): operator to transform
- filter_key (str): filter key
- filter_value (str): filter value
- property_obj (object): property object
- Returns:
- tuple: operator, deflated_value
- """
- if not isinstance(filter_value, tuple) and not isinstance(filter_value, list):
- raise ValueError(
- f"Value must be a tuple or list for IN operation {filter_key}={filter_value}"
- )
- if isinstance(property_obj, ArrayProperty):
- deflated_value = property_obj.deflate(filter_value)
- operator = _SPECIAL_OPERATOR_ARRAY_IN
+def _deflate_value(
+ cls, property_obj: Property, key: str, value: str, operator: str, prop: str
+) -> Tuple[str, str, str]:
+ if isinstance(property_obj, AliasProperty):
+ prop = property_obj.aliased_to()
+ deflated_value = getattr(cls, prop).deflate(value)
else:
- deflated_value = [property_obj.deflate(v) for v in filter_value]
+ # handle special operators
+ deflated_value, operator, prop = _handle_special_operators(
+ property_obj, key, value, operator, prop
+ )
- return operator, deflated_value
+ return deflated_value, operator, prop
+
+
+def _initialize_filter_args_variables(cls, key: str):
+ current_class = cls
+ current_rel_model = None
+ leaf_prop = None
+ operator = "="
+ is_rel_property = "|" in key
+ prop = key
+
+ return current_class, current_rel_model, leaf_prop, operator, is_rel_property, prop
+
+
+def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]:
+ (
+ current_class,
+ current_rel_model,
+ leaf_prop,
+ operator,
+ is_rel_property,
+ prop,
+ ) = _initialize_filter_args_variables(cls, key)
+
+ for part in re.split(path_split_regex, key):
+ defined_props = current_class.defined_properties(rels=True)
+ # update defined props dictionary with relationship properties if
+ # we are filtering by property
+ if is_rel_property and current_rel_model:
+ defined_props.update(current_rel_model.defined_properties(rels=True))
+ if part in defined_props:
+ if isinstance(
+ defined_props[part], relationship_manager.AsyncRelationshipDefinition
+ ):
+ defined_props[part].lookup_node_class()
+ current_class = defined_props[part].definition["node_class"]
+ current_rel_model = defined_props[part].definition["model"]
+ elif part in OPERATOR_TABLE:
+ operator = OPERATOR_TABLE[part]
+ prop, _ = prop.rsplit("__", 1)
+ continue
+ else:
+ raise ValueError(
+ f"No such property {part} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
+ )
+ leaf_prop = part
+ if is_rel_property and current_rel_model:
+ property_obj = getattr(current_rel_model, leaf_prop)
+ else:
+ property_obj = getattr(current_class, leaf_prop)
-def transform_null_operator_to_filter(filter_key, filter_value):
- """
- Transform null operator to a cypher filter
- Args:
- filter_key (str): filter key
- filter_value (str): filter value
- Returns:
- tuple: operator, deflated_value
- """
- if not isinstance(filter_value, bool):
- raise ValueError(f"Value must be a bool for isnull operation on {filter_key}")
- operator = "IS NULL" if filter_value else "IS NOT NULL"
- deflated_value = None
- return operator, deflated_value
+ return property_obj, operator, prop
-def transform_regex_operator_to_filter(
- operator, filter_key, filter_value, property_obj
-):
+def process_filter_args(cls, kwargs) -> Dict:
"""
- Transform regex operator to a cypher filter
- Args:
- operator (str): operator to transform
- filter_key (str): filter key
- filter_value (str): filter value
- property_obj (object): property object
- Returns:
- tuple: operator, deflated_value
+ loop through properties in filter parameters check they match class definition
+ deflate them and convert into something easy to generate cypher from
"""
+ output = {}
- deflated_value = property_obj.deflate(filter_value)
- if not isinstance(deflated_value, str):
- raise ValueError(f"Must be a string value for {filter_key}")
- if operator in _STRING_REGEX_OPERATOR_TABLE.values():
- deflated_value = re.escape(deflated_value)
- deflated_value = operator.format(deflated_value)
- operator = _SPECIAL_OPERATOR_REGEX
- return operator, deflated_value
-
-
-def transform_operator_to_filter(operator, filter_key, filter_value, property_obj):
- if operator == _SPECIAL_OPERATOR_IN:
- operator, deflated_value = transform_in_operator_to_filter(
- operator=operator,
- filter_key=filter_key,
- filter_value=filter_value,
- property_obj=property_obj,
- )
- elif operator == _SPECIAL_OPERATOR_ISNULL:
- operator, deflated_value = transform_null_operator_to_filter(
- filter_key=filter_key, filter_value=filter_value
- )
- elif operator in _REGEX_OPERATOR_TABLE.values():
- operator, deflated_value = transform_regex_operator_to_filter(
- operator=operator,
- filter_key=filter_key,
- filter_value=filter_value,
- property_obj=property_obj,
+ for key, value in kwargs.items():
+ property_obj, operator, prop = _process_filter_key(cls, key)
+ deflated_value, operator, prop = _deflate_value(
+ cls, property_obj, key, value, operator, prop
)
- else:
- deflated_value = property_obj.deflate(filter_value)
+ # map property to correct property name in the database
+ db_property = prop
- return operator, deflated_value
+ output[db_property] = (operator, deflated_value)
+ return output
def process_has_args(cls, kwargs):
@@ -374,34 +368,34 @@ def process_has_args(cls, kwargs):
class QueryAST:
- match: Optional[list]
- optional_match: Optional[list]
- where: Optional[list]
- with_clause: Optional[str]
- return_clause: Optional[str]
- order_by: Optional[str]
- skip: Optional[int]
- limit: Optional[int]
- result_class: Optional[type]
- lookup: Optional[str]
- additional_return: Optional[list]
- is_count: Optional[bool]
+ match: List[str]
+ optional_match: List[str]
+ where: List[str]
+ with_clause: TOptional[str]
+ return_clause: TOptional[str]
+ order_by: TOptional[List[str]]
+ skip: TOptional[int]
+ limit: TOptional[int]
+ result_class: TOptional[type]
+ lookup: TOptional[str]
+ additional_return: List[str]
+ is_count: TOptional[bool]
def __init__(
self,
- match: Optional[list] = None,
- optional_match: Optional[list] = None,
- where: Optional[list] = None,
- with_clause: Optional[str] = None,
- return_clause: Optional[str] = None,
- order_by: Optional[str] = None,
- skip: Optional[int] = None,
- limit: Optional[int] = None,
- result_class: Optional[type] = None,
- lookup: Optional[str] = None,
- additional_return: Optional[list] = None,
- is_count: Optional[bool] = False,
- ):
+ match: TOptional[List[str]] = None,
+ optional_match: TOptional[List[str]] = None,
+ where: TOptional[List[str]] = None,
+ with_clause: TOptional[str] = None,
+ return_clause: TOptional[str] = None,
+ order_by: TOptional[List[str]] = None,
+ skip: TOptional[int] = None,
+ limit: TOptional[int] = None,
+ result_class: TOptional[type] = None,
+ lookup: TOptional[str] = None,
+ additional_return: TOptional[List[str]] = None,
+ is_count: TOptional[bool] = False,
+ ) -> None:
self.match = match if match else []
self.optional_match = optional_match if optional_match else []
self.where = where if where else []
@@ -414,18 +408,19 @@ def __init__(
self.lookup = lookup
self.additional_return = additional_return if additional_return else []
self.is_count = is_count
+ self.subgraph: Dict = {}
class AsyncQueryBuilder:
- def __init__(self, node_set):
+ def __init__(self, node_set, subquery_context: bool = False) -> None:
self.node_set = node_set
self._ast = QueryAST()
- self._query_params = {}
- self._place_holder_registry = {}
- self._ident_count = 0
- self._node_counters = defaultdict(int)
+ self._query_params: Dict = {}
+ self._place_holder_registry: Dict = {}
+ self._ident_count: int = 0
+ self._subquery_context: bool = subquery_context
- async def build_ast(self):
+ async def build_ast(self) -> "AsyncQueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
for relation in self.node_set.relations_to_fetch:
self.build_traversal_from_path(relation, self.node_set.source)
@@ -439,7 +434,7 @@ async def build_ast(self):
return self
- async def build_source(self, source):
+ async def build_source(self, source) -> str:
if isinstance(source, AsyncTraversal):
return await self.build_traversal(source)
if isinstance(source, AsyncNodeSet):
@@ -468,18 +463,40 @@ async def build_source(self, source):
return await self.build_node(source)
raise ValueError("Unknown source type " + repr(source))
- def create_ident(self):
+ def create_ident(self) -> str:
self._ident_count += 1
- return "r" + str(self._ident_count)
+ return f"r{self._ident_count}"
- def build_order_by(self, ident, source):
+ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
if "?" in source.order_by_elements:
self._ast.with_clause = f"{ident}, rand() as r"
- self._ast.order_by = "r"
+ self._ast.order_by = ["r"]
else:
- self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements]
+ order_by = []
+ for elm in source.order_by_elements:
+ if isinstance(elm, RawCypher):
+ order_by.append(elm.render({"n": ident}))
+ continue
+ is_rel_property = "|" in elm
+ if "__" not in elm and not is_rel_property:
+ prop = elm.split(" ")[0] if " " in elm else elm
+ if prop not in source.source_class.defined_properties(rels=False):
+ raise ValueError(
+ f"No such property {prop} on {source.source_class.__name__}. "
+ f"Note that Neo4j internals like id or element_id are not allowed "
+ f"for use in this operation."
+ )
+ order_by.append(f"{ident}.{elm}")
+ else:
+ path, prop = elm.rsplit("__" if not is_rel_property else "|", 1)
+ result = self.lookup_query_variable(
+ path, return_relation=is_rel_property
+ )
+ if result:
+ order_by.append(f"{result[0]}.{prop}")
+ self._ast.order_by = order_by
- async def build_traversal(self, traversal):
+ async def build_traversal(self, traversal) -> str:
"""
traverse a relationship from a node to a set of nodes
"""
@@ -507,27 +524,29 @@ async def build_traversal(self, traversal):
return traversal_ident
- def _additional_return(self, name):
+ def _additional_return(self, name: str):
if name not in self._ast.additional_return and name != self._ast.return_clause:
self._ast.additional_return.append(name)
- def build_traversal_from_path(self, relation: dict, source_class) -> str:
+ def build_traversal_from_path(
+ self, relation: dict, source_class
+ ) -> Tuple[str, Any]:
path: str = relation["path"]
stmt: str = ""
source_class_iterator = source_class
- for index, part in enumerate(path.split("__")):
+ parts = re.split(path_split_regex, path)
+ subgraph = self._ast.subgraph
+ rel_iterator: str = ""
+ already_present = False
+ existing_rhs_name = ""
+ for index, part in enumerate(parts):
relationship = getattr(source_class_iterator, part)
+ if rel_iterator:
+ rel_iterator += "__"
+ rel_iterator += part
# build source
if "node_class" not in relationship.definition:
relationship.lookup_node_class()
- rhs_label = relationship.definition["node_class"].__label__
- rel_reference = f'{relationship.definition["node_class"]}_{part}'
- self._node_counters[rel_reference] += 1
- rhs_name = (
- f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
- )
- rhs_ident = f"{rhs_name}:{rhs_label}"
- self._additional_return(rhs_name)
if not stmt:
lhs_label = source_class_iterator.__label__
lhs_name = lhs_label.lower()
@@ -537,13 +556,44 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
- else:
+ if self._subquery_context:
+ # Don't include label in identifier if we are in a subquery
+ lhs_ident = lhs_name
+ elif relation["include_in_return"]:
self._additional_return(lhs_name)
else:
lhs_ident = stmt
+ already_present = part in subgraph
rel_ident = self.create_ident()
- self._additional_return(rel_ident)
+ rhs_label = relationship.definition["node_class"].__label__
+ if relation.get("relation_filtering"):
+ rhs_name = rel_ident
+ else:
+ if index + 1 == len(parts) and "alias" in relation:
+ # If an alias is defined, use it to store the last hop in the path
+ rhs_name = relation["alias"]
+ else:
+ rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
+ rhs_ident = f"{rhs_name}:{rhs_label}"
+ if relation["include_in_return"] and not already_present:
+ self._additional_return(rhs_name)
+
+ if not already_present:
+ subgraph[part] = {
+ "target": relationship.definition["node_class"],
+ "children": {},
+ "variable_name": rhs_name,
+ "rel_variable_name": rel_ident,
+ }
+ else:
+ existing_rhs_name = subgraph[part][
+ "rel_variable_name"
+ if relation.get("relation_filtering")
+ else "variable_name"
+ ]
+ if relation["include_in_return"] and not already_present:
+ self._additional_return(rel_ident)
stmt = _rel_helper(
lhs=lhs_ident,
rhs=rhs_ident,
@@ -552,12 +602,16 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
relation_type=relationship.definition["relation_type"],
)
source_class_iterator = relationship.definition["node_class"]
+ subgraph = subgraph[part]["children"]
- if relation.get("optional"):
- self._ast.optional_match.append(stmt)
- else:
- self._ast.match.append(stmt)
- return rhs_name
+ if not already_present:
+ if relation.get("optional"):
+ self._ast.optional_match.append(stmt)
+ else:
+ self._ast.match.append(stmt)
+ return rhs_name, relationship.definition["node_class"]
+
+ return existing_rhs_name, relationship.definition["node_class"]
async def build_node(self, node):
ident = node.__class__.__name__.lower()
@@ -573,7 +627,7 @@ async def build_node(self, node):
self._ast.result_class = node.__class__
return ident
- def build_label(self, ident, cls):
+ def build_label(self, ident, cls) -> str:
"""
match nodes by a label
"""
@@ -609,14 +663,71 @@ def build_additional_match(self, ident, node_set):
else:
raise ValueError("Expecting dict got: " + repr(val))
- def _register_place_holder(self, key):
+ def _register_place_holder(self, key: str) -> str:
if key in self._place_holder_registry:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
return key + "_" + str(self._place_holder_registry[key])
- def _parse_q_filters(self, ident, q, source_class):
+ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
+ is_rel_filter = "|" in prop
+ if is_rel_filter:
+ path, prop = prop.rsplit("|", 1)
+ else:
+ path, prop = prop.rsplit("__", 1)
+ result = self.lookup_query_variable(path, return_relation=is_rel_filter)
+ if not result:
+ ident, target_class = self.build_traversal_from_path(
+ {
+ "path": path,
+ "include_in_return": True,
+ "relation_filtering": is_rel_filter,
+ },
+ source_class,
+ )
+ else:
+ ident, target_class = result
+ return ident, path, prop, target_class
+
+ def _finalize_filter_statement(
+ self, operator: str, ident: str, prop: str, val: Any
+ ) -> str:
+ if operator in _UNARY_OPERATORS:
+ # unary operators do not have a parameter
+ statement = f"{ident}.{prop} {operator}"
+ else:
+ place_holder = self._register_place_holder(ident + "_" + prop)
+ if operator == _SPECIAL_OPERATOR_ARRAY_IN:
+ statement = operator.format(
+ ident=ident,
+ prop=prop,
+ val=f"${place_holder}",
+ )
+ else:
+ statement = f"{ident}.{prop} {operator} ${place_holder}"
+ self._query_params[place_holder] = val
+
+ return statement
+
+ def _build_filter_statements(
+ self, ident: str, filters, target: List[str], source_class
+ ) -> None:
+ for prop, op_and_val in filters.items():
+ path = None
+ is_rel_filter = "|" in prop
+ target_class = source_class
+ if "__" in prop or is_rel_filter:
+ ident, path, prop, target_class = self._parse_path(source_class, prop)
+ operator, val = op_and_val
+ if not is_rel_filter:
+ prop = target_class.defined_properties(rels=False)[
+ prop
+ ].get_db_property_name(prop)
+ statement = self._finalize_filter_statement(operator, ident, prop, val)
+ target.append(statement)
+
+ def _parse_q_filters(self, ident, q, source_class) -> str:
target = []
for child in q.children:
if isinstance(child, QBase):
@@ -627,36 +738,22 @@ def _parse_q_filters(self, ident, q, source_class):
else:
kwargs = {child[0]: child[1]}
filters = process_filter_args(source_class, kwargs)
- for prop, op_and_val in filters.items():
- operator, val = op_and_val
- if operator in _UNARY_OPERATORS:
- # unary operators do not have a parameter
- statement = f"{ident}.{prop} {operator}"
- else:
- place_holder = self._register_place_holder(ident + "_" + prop)
- if operator == _SPECIAL_OPERATOR_ARRAY_IN:
- statement = operator.format(
- ident=ident,
- prop=prop,
- val=f"${place_holder}",
- )
- else:
- statement = f"{ident}.{prop} {operator} ${place_holder}"
- self._query_params[place_holder] = val
- target.append(statement)
+ self._build_filter_statements(ident, filters, target, source_class)
ret = f" {q.connector} ".join(target)
if q.negated:
ret = f"NOT ({ret})"
return ret
- def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
+ def build_where_stmt(
+ self, ident: str, filters, q_filters=None, source_class=None
+ ) -> None:
"""
construct a where statement from some filters
"""
if q_filters is not None:
- stmts = self._parse_q_filters(ident, q_filters, source_class)
- if stmts:
- self._ast.where.append(stmts)
+ stmt = self._parse_q_filters(ident, q_filters, source_class)
+ if stmt:
+ self._ast.where.append(stmt)
else:
stmts = []
for row in filters:
@@ -682,8 +779,37 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
self._ast.where.append(" AND ".join(stmts))
- def build_query(self):
- query = ""
+ def lookup_query_variable(
+ self, path: str, return_relation: bool = False
+ ) -> TOptional[Tuple[str, Any]]:
+ """Retrieve the variable name generated internally for the given traversal path."""
+ subgraph = self._ast.subgraph
+ if not subgraph:
+ return None
+ traversals = re.split(path_split_regex, path)
+ if len(traversals) == 0:
+ raise ValueError("Can only lookup traversal variables")
+ if traversals[0] not in subgraph:
+ return None
+ subgraph = subgraph[traversals[0]]
+ if len(traversals) == 1:
+ variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}"
+ return variable_to_return, subgraph["target"]
+ variable_to_return = ""
+ last_property = traversals[-1]
+ for part in traversals[1:]:
+ child = subgraph["children"].get(part)
+ if not child:
+ return None
+ subgraph = child
+ if part == last_property:
+ # if last part of prop is the last traversal
+ # we are safe to lookup the variable from the query
+ variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}"
+ return variable_to_return, subgraph["target"]
+
+ def build_query(self) -> str:
+ query: str = ""
if self._ast.lookup:
query += self._ast.lookup
@@ -702,6 +828,9 @@ def build_query(self):
query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match)
if self._ast.where:
+ if self._ast.optional_match:
+ # Make sure filtering works as expected with optional match, even if it's not performant...
+ query += " WITH *"
query += " WHERE "
query += " AND ".join(self._ast.where)
@@ -709,13 +838,85 @@ def build_query(self):
query += " WITH "
query += self._ast.with_clause
+ if hasattr(self.node_set, "_intermediate_transforms"):
+ for transform in self.node_set._intermediate_transforms:
+ query += " WITH "
+ injected_vars: list = []
+ # Reset return list since we'll probably invalidate most variables
+ self._ast.return_clause = ""
+ self._ast.additional_return = []
+ for name, source in transform["vars"].items():
+ if type(source) is str:
+ injected_vars.append(f"{source} AS {name}")
+ elif isinstance(source, RelationNameResolver):
+ result = self.lookup_query_variable(
+ source.relation, return_relation=True
+ )
+ if not result:
+ raise ValueError(
+ f"Unable to resolve variable name for relation {source.relation}."
+ )
+ injected_vars.append(f"{result[0]} AS {name}")
+ elif isinstance(source, NodeNameResolver):
+ result = self.lookup_query_variable(source.node)
+ if not result:
+ raise ValueError(
+ f"Unable to resolve variable name for node {source.node}."
+ )
+ injected_vars.append(f"{result[0]} AS {name}")
+ query += ",".join(injected_vars)
+ if not transform["ordering"]:
+ continue
+ query += " ORDER BY "
+ ordering: list = []
+ for item in transform["ordering"]:
+ if isinstance(item, RawCypher):
+ ordering.append(item.render({}))
+ continue
+ if item.startswith("-"):
+ ordering.append(f"{item[1:]} DESC")
+ else:
+ ordering.append(item)
+ query += ",".join(ordering)
+
+ returned_items: list[str] = []
+ if hasattr(self.node_set, "_subqueries"):
+ for subquery, return_set in self.node_set._subqueries:
+ outer_primary_var = self._ast.return_clause
+ query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
+ for varname in return_set:
+ # We declare the returned variables as "virtual" relations of the
+ # root node class to make sure they will be translated by a call to
+ # resolve_subgraph() (otherwise, they will be lost).
+ # This is probably a temporary solution until we find something better...
+ self._ast.subgraph[varname] = {
+ "target": None, # We don't need target class in this use case
+ "children": {},
+ "variable_name": varname,
+ "rel_variable_name": varname,
+ }
+ returned_items += return_set
+
query += " RETURN "
- if self._ast.return_clause:
- query += self._ast.return_clause
+ if self._ast.return_clause and not self._subquery_context:
+ returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
- if self._ast.return_clause:
- query += ", "
- query += ", ".join(self._ast.additional_return)
+ returned_items += self._ast.additional_return
+ if hasattr(self.node_set, "_extra_results"):
+ for props in self.node_set._extra_results:
+ leftpart = props["vardef"].render(self)
+ varname = (
+ props["alias"]
+ if props.get("alias")
+ else props["vardef"].get_internal_name()
+ )
+ if varname in returned_items:
+ # We're about to override an existing variable, delete it first to
+ # avoid duplicate error
+ returned_items.remove(varname)
+ returned_items.append(f"{leftpart} AS {varname}")
+
+ query += ", ".join(returned_items)
if self._ast.order_by:
query += " ORDER BY "
@@ -754,7 +955,6 @@ async def _count(self):
async def _contains(self, node_element_id):
# inject id = into ast
if not self._ast.return_clause:
- print(self._ast.additional_return)
self._ast.return_clause = self._ast.additional_return[0]
ident = self._ast.return_clause
place_holder = self._register_place_holder(ident + "_contains")
@@ -764,7 +964,7 @@ async def _contains(self, node_element_id):
self._query_params[place_holder] = node_element_id
return await self._count() >= 1
- async def _execute(self, lazy=False):
+ async def _execute(self, lazy: bool = False, dict_output: bool = False):
if lazy:
# inject id() into return or return_set
if self._ast.return_clause:
@@ -777,9 +977,13 @@ async def _execute(self, lazy=False):
for item in self._ast.additional_return
]
query = self.build_query()
- results, _ = await adb.cypher_query(
+ results, prop_names = await adb.cypher_query(
query, self._query_params, resolve_objects=True
)
+ if dict_output:
+ for item in results:
+ yield dict(zip(prop_names, item))
+ return
# The following is not as elegant as it could be but had to be copied from the
# version prior to cypher_query with the resolve_objects capability.
# It seems that certain calls are only supposed to be focusing to the first
@@ -800,6 +1004,7 @@ class AsyncBaseSet:
"""
query_cls = AsyncQueryBuilder
+ source_class: AsyncStructuredNode
async def all(self, lazy=False):
"""
@@ -823,7 +1028,7 @@ async def get_len(self):
ast = await self.query_cls(self).build_ast()
return await ast._count()
- async def check_bool(self):
+ async def check_bool(self) -> bool:
"""
Override for __bool__ dunder method.
:return: True if the set contains any nodes, False otherwise
@@ -833,7 +1038,7 @@ async def check_bool(self):
_count = await ast._count()
return _count > 0
- async def check_nonzero(self):
+ async def check_nonzero(self) -> bool:
"""
Override for __bool__ dunder method.
:return: True if the set contains any node, False otherwise
@@ -881,12 +1086,126 @@ class Optional:
relation: str
+@dataclass
+class RelationNameResolver:
+ """Helper to refer to a relation variable name.
+
+ Since variable names are generated automatically within MATCH statements (for
+ anything injected using fetch_relations or traverse_relations), we need a way to
+ retrieve them.
+
+ """
+
+ relation: str
+
+
+@dataclass
+class NodeNameResolver:
+ """Helper to refer to a node variable name.
+
+ Since variable names are generated automatically within MATCH statements (for
+ anything injected using fetch_relations or traverse_relations), we need a way to
+ retrieve them.
+
+ """
+
+ node: str
+
+
+@dataclass
+class BaseFunction:
+ input_name: Union[str, "BaseFunction", NodeNameResolver, RelationNameResolver]
+
+ def __post_init__(self) -> None:
+ self._internal_name: str = ""
+
+ def get_internal_name(self) -> str:
+ return self._internal_name
+
+ def resolve_internal_name(self, qbuilder: AsyncQueryBuilder) -> str:
+ if isinstance(self.input_name, NodeNameResolver):
+ result = qbuilder.lookup_query_variable(self.input_name.node)
+ elif isinstance(self.input_name, RelationNameResolver):
+ result = qbuilder.lookup_query_variable(self.input_name.relation, True)
+ else:
+ result = (str(self.input_name), None)
+ if result is None:
+ raise ValueError(f"Unknown variable {self.input_name} used in Collect()")
+ self._internal_name = result[0]
+ return self._internal_name
+
+ def render(self, qbuilder: AsyncQueryBuilder) -> str:
+ raise NotImplementedError
+
+
+@dataclass
+class AggregatingFunction(BaseFunction):
+ """Base aggregating function class."""
+
+ pass
+
+
+@dataclass
+class Collect(AggregatingFunction):
+ """collect() function."""
+
+ distinct: bool = False
+
+ def render(self, qbuilder: AsyncQueryBuilder) -> str:
+ varname = self.resolve_internal_name(qbuilder)
+ if self.distinct:
+ return f"collect(DISTINCT {varname})"
+ return f"collect({varname})"
+
+
+@dataclass
+class ScalarFunction(BaseFunction):
+ """Base scalar function class."""
+
+ pass
+
+
+@dataclass
+class Last(ScalarFunction):
+ """last() function."""
+
+ def render(self, qbuilder: AsyncQueryBuilder) -> str:
+ if isinstance(self.input_name, str):
+ content = str(self.input_name)
+ elif isinstance(self.input_name, BaseFunction):
+ content = self.input_name.render(qbuilder)
+ self._internal_name = self.input_name.get_internal_name()
+ else:
+ content = self.resolve_internal_name(qbuilder)
+ return f"last({content})"
+
+
+@dataclass
+class RawCypher:
+ """Helper to inject raw cypher statement.
+
+ Can be used in order_by() call for example.
+
+ """
+
+ statement: str
+
+ def __post_init__(self):
+ if CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR.search(self.statement):
+ raise ValueError(
+ "RawCypher: Do not include any action that has side effect"
+ )
+
+ def render(self, context: Dict) -> str:
+ return string.Template(self.statement).substitute(context)
+
+
class AsyncNodeSet(AsyncBaseSet):
"""
A class representing as set of nodes matching common query parameters
"""
- def __init__(self, source):
+ def __init__(self, source) -> None:
self.source = source # could be a Traverse object or a node class
if isinstance(source, AsyncTraversal):
self.source_class = source.target_class
@@ -900,14 +1219,18 @@ def __init__(self, source):
# setup Traversal objects using relationship definitions
install_traversals(self.source_class, self)
- self.filters = []
+ self.filters: List = []
self.q_filters = Q()
+ self.order_by_elements: List = []
# used by has()
- self.must_match = {}
- self.dont_match = {}
+ self.must_match: Dict = {}
+ self.dont_match: Dict = {}
- self.relations_to_fetch: list = []
+ self.relations_to_fetch: List = []
+ self._extra_results: List = []
+ self._subqueries: list[Tuple[str, list[str]]] = []
+ self._intermediate_transforms: list = []
def __await__(self):
return self.all().__await__()
@@ -972,7 +1295,7 @@ async def first_or_none(self, **kwargs):
pass
return None
- def filter(self, *args, **kwargs):
+ def filter(self, *args, **kwargs) -> "AsyncBaseSet":
"""
Apply filters to the existing nodes in the set.
@@ -1042,6 +1365,9 @@ def order_by(self, *props):
self.order_by_elements.append("?")
else:
for prop in props:
+ if isinstance(prop, RawCypher):
+ self.order_by_elements.append(prop)
+ continue
prop = prop.strip()
if prop.startswith("-"):
prop = prop[1:]
@@ -1049,31 +1375,185 @@ def order_by(self, *props):
else:
desc = False
- if prop not in self.source_class.defined_properties(rels=False):
- raise ValueError(
- f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
- )
-
- property_obj = getattr(self.source_class, prop)
- if isinstance(property_obj, AliasProperty):
- prop = property_obj.aliased_to()
+ if prop in self.source_class.defined_properties(rels=False):
+ property_obj = getattr(self.source_class, prop)
+ if isinstance(property_obj, AliasProperty):
+ prop = property_obj.aliased_to()
self.order_by_elements.append(prop + (" DESC" if desc else ""))
return self
+ def _register_relation_to_fetch(
+ self,
+ relation_def: Any,
+ alias: TOptional[str] = None,
+ include_in_return: bool = True,
+ ):
+ if isinstance(relation_def, Optional):
+ item = {"path": relation_def.relation, "optional": True}
+ else:
+ item = {"path": relation_def}
+ item["include_in_return"] = include_in_return
+ if alias:
+ item["alias"] = alias
+ return item
+
def fetch_relations(self, *relation_names):
- """Specify a set of relations to return."""
+ """Specify a set of relations to traverse and return."""
relations = []
for relation_name in relation_names:
- if isinstance(relation_name, Optional):
- item = {"path": relation_name.relation, "optional": True}
- else:
- item = {"path": relation_name}
- relations.append(item)
+ relations.append(self._register_relation_to_fetch(relation_name))
+ self.relations_to_fetch = relations
+ return self
+
+ def traverse_relations(self, *relation_names, **aliased_relation_names):
+ """Specify a set of relations to traverse only."""
+ relations = []
+ for relation_name in relation_names:
+ relations.append(
+ self._register_relation_to_fetch(relation_name, include_in_return=False)
+ )
+ for alias, relation_def in aliased_relation_names.items():
+ relations.append(
+ self._register_relation_to_fetch(
+ relation_def, alias, include_in_return=False
+ )
+ )
+
self.relations_to_fetch = relations
return self
+ def annotate(self, *vars, **aliased_vars):
+ """Annotate node set results with extra variables."""
+
+ def register_extra_var(vardef, varname: Union[str, None] = None):
+ if isinstance(vardef, (AggregatingFunction, ScalarFunction)):
+ self._extra_results.append(
+ {"vardef": vardef, "alias": varname if varname else ""}
+ )
+ else:
+ raise NotImplementedError
+
+ for vardef in vars:
+ register_extra_var(vardef)
+ for varname, vardef in aliased_vars.items():
+ register_extra_var(vardef, varname)
+
+ return self
+
+ def _to_subgraph(self, root_node, other_nodes, subgraph):
+ """Recursive method to build root_node's relation graph from subgraph."""
+ root_node._relations = {}
+ for name, relation_def in subgraph.items():
+ for var_name, node in other_nodes.items():
+ if (
+ var_name
+ not in [
+ relation_def["variable_name"],
+ relation_def["rel_variable_name"],
+ ]
+ or node is None
+ ):
+ continue
+ if isinstance(node, list):
+ if len(node) > 0 and isinstance(node[0], AsyncStructuredRel):
+ name += "_relationship"
+ root_node._relations[name] = []
+ for item in node:
+ root_node._relations[name].append(
+ self._to_subgraph(
+ item, other_nodes, relation_def["children"]
+ )
+ )
+ else:
+ if isinstance(node, AsyncStructuredRel):
+ name += "_relationship"
+ root_node._relations[name] = self._to_subgraph(
+ node, other_nodes, relation_def["children"]
+ )
+
+ return root_node
+
+ async def resolve_subgraph(self) -> list:
+ """
+ Convert every result contained in this node set to a subgraph.
+
+ By default, we receive results from neomodel as a list of
+ nodes without the hierarchy. This method tries to rebuild this
+ hierarchy without overriding anything in the node; that's why
+ we use a dedicated property to store node's relations.
+
+ """
+ if (
+ self.relations_to_fetch
+ and not self.relations_to_fetch[0]["include_in_return"]
+ ):
+ raise NotImplementedError(
+ "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
+ )
+ results: list = []
+ qbuilder = self.query_cls(self)
+ await qbuilder.build_ast()
+ if not qbuilder._ast.subgraph:
+ raise RuntimeError(
+ "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
+ )
+ other_nodes = {}
+ root_node = None
+ async for row in qbuilder._execute(dict_output=True):
+ for name, node in row.items():
+ if node.__class__ is self.source and "_" not in name:
+ root_node = node
+ continue
+ if isinstance(node, list) and isinstance(node[0], list):
+ other_nodes[name] = node[0]
+ continue
+ other_nodes[name] = node
+ results.append(
+ self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph)
+ )
+ return results
+
+ async def subquery(
+ self, nodeset: "AsyncNodeSet", return_set: List[str]
+ ) -> "AsyncNodeSet":
+ """Add a subquery to this node set.
+
+ A subquery is a regular cypher query but executed within the context of a CALL
+ statement. Such a query will generally fetch additional variables, which must be
+ declared inside return_set variable in order to be included in the final RETURN
+ statement.
+ """
+ qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast()
+ for var in return_set:
+ if (
+ var != qbuilder._ast.return_clause
+ and var not in qbuilder._ast.additional_return
+ and var
+ not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
+ ):
+ raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
+ self._subqueries.append((qbuilder.build_query(), return_set))
+ return self
+
+ def intermediate_transform(
+ self, vars: Dict[str, Any], ordering: TOptional[list] = None
+ ) -> "AsyncNodeSet":
+ if not vars:
+ raise ValueError(
+ "You must provide one variable at least when calling intermediate_transform()"
+ )
+ for name, source in vars.items():
+ if type(source) is not str and not isinstance(
+ source, (NodeNameResolver, RelationNameResolver)
+ ):
+ raise ValueError(
+ f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver"
+ )
+ self._intermediate_transforms.append({"vars": vars, "ordering": ordering})
+ return self
+
class AsyncTraversal(AsyncBaseSet):
"""
@@ -1087,13 +1567,13 @@ class AsyncTraversal(AsyncBaseSet):
:type name: :class:`str`
:param definition: A relationship definition that most certainly deserves
a documentation here.
- :type defintion: :class:`dict`
+ :type definition: :class:`dict`
"""
def __await__(self):
return self.all().__await__()
- def __init__(self, source, name, definition):
+ def __init__(self, source, name, definition) -> None:
"""
Create a traversal
@@ -1123,7 +1603,7 @@ def __init__(self, source, name, definition):
self.definition = definition
self.target_class = definition["node_class"]
self.name = name
- self.filters = []
+ self.filters: List = []
def match(self, **kwargs):
"""
diff --git a/neomodel/properties.py b/neomodel/properties.py
index d4a91885..e28e2df3 100644
--- a/neomodel/properties.py
+++ b/neomodel/properties.py
@@ -1,9 +1,10 @@
import functools
import json
import re
-import sys
import uuid
+from abc import ABCMeta, abstractmethod
from datetime import date, datetime
+from typing import Any
import neo4j.time
import pytz
@@ -37,7 +38,7 @@ def _validator(self, value, obj=None, rethrow=True):
return _validator
-class FulltextIndex(object):
+class FulltextIndex:
"""
Fulltext index definition
"""
@@ -57,7 +58,7 @@ def __init__(
self.eventually_consistent = eventually_consistent
-class VectorIndex(object):
+class VectorIndex:
"""
Vector index definition
"""
@@ -73,7 +74,7 @@ def __init__(self, dimensions=1536, similarity_function="cosine"):
self.similarity_function = similarity_function
-class Property:
+class Property(metaclass=ABCMeta):
"""
Base class for object properties.
@@ -158,6 +159,10 @@ def get_db_property_name(self, attribute_name):
def is_indexed(self):
return self.unique_index or self.index
+ @abstractmethod
+ def deflate(self, value: Any) -> Any:
+ pass
+
class NormalizedProperty(Property):
"""
@@ -413,7 +418,7 @@ class DateTimeFormatProperty(Property):
"""
Store a datetime by custom format
:param default_now: If ``True``, the creation time (Local) will be used as default.
- Defaults to ``False``.
+ Defaults to ``False``.
:param format: Date format string, default is %Y-%m-%d
:type default_now: :class:`bool`
@@ -529,8 +534,9 @@ class JSONProperty(Property):
The structure will be inflated when a node is retrieved.
"""
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+ def __init__(self, ensure_ascii=True, *args, **kwargs):
+ self.ensure_ascii = ensure_ascii
+ super(JSONProperty, self).__init__(*args, **kwargs)
@validator
def inflate(self, value):
@@ -538,7 +544,7 @@ def inflate(self, value):
@validator
def deflate(self, value):
- return json.dumps(value)
+ return json.dumps(value, ensure_ascii=self.ensure_ascii)
class AliasProperty(property, Property):
diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py
index 928842c2..966b2601 100644
--- a/neomodel/sync_/match.py
+++ b/neomodel/sync_/match.py
@@ -1,15 +1,21 @@
import inspect
import re
-from collections import defaultdict
+import string
from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, List
+from typing import Optional as TOptional
+from typing import Tuple, Union
from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
-from neomodel.properties import AliasProperty, ArrayProperty
+from neomodel.properties import AliasProperty, ArrayProperty, Property
+from neomodel.sync_ import relationship_manager
from neomodel.sync_.core import StructuredNode, db
+from neomodel.sync_.relationship import StructuredRel
from neomodel.util import INCOMING, OUTGOING
+CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
+
def _rel_helper(
lhs,
@@ -194,6 +200,8 @@ def _rel_merge_helper(
# add all regex operators
OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE)
+path_split_regex = re.compile(r"__(?!_)|\|")
+
def install_traversals(cls, node_set):
"""
@@ -213,136 +221,122 @@ def install_traversals(cls, node_set):
setattr(node_set, key, traversal)
-def process_filter_args(cls, kwargs):
- """
- loop through properties in filter parameters check they match class definition
- deflate them and convert into something easy to generate cypher from
- """
-
- output = {}
-
- for key, value in kwargs.items():
- if "__" in key:
- prop, operator = key.rsplit("__")
- operator = OPERATOR_TABLE[operator]
- else:
- prop = key
- operator = "="
-
- if prop not in cls.defined_properties(rels=False):
+def _handle_special_operators(
+ property_obj: Property, key: str, value: str, operator: str, prop: str
+) -> Tuple[str, str, str]:
+ if operator == _SPECIAL_OPERATOR_IN:
+ if not isinstance(value, (list, tuple)):
raise ValueError(
- f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
+ f"Value must be a tuple or list for IN operation {key}={value}"
)
-
- property_obj = getattr(cls, prop)
- if isinstance(property_obj, AliasProperty):
- prop = property_obj.aliased_to()
- deflated_value = getattr(cls, prop).deflate(value)
+ if isinstance(property_obj, ArrayProperty):
+ deflated_value = property_obj.deflate(value)
+ operator = _SPECIAL_OPERATOR_ARRAY_IN
else:
- operator, deflated_value = transform_operator_to_filter(
- operator=operator,
- filter_key=key,
- filter_value=value,
- property_obj=property_obj,
- )
-
- # map property to correct property name in the database
- db_property = cls.defined_properties(rels=False)[prop].get_db_property_name(
- prop
- )
-
- output[db_property] = (operator, deflated_value)
+ deflated_value = [property_obj.deflate(v) for v in value]
+ elif operator == _SPECIAL_OPERATOR_ISNULL:
+ if not isinstance(value, bool):
+ raise ValueError(f"Value must be a bool for isnull operation on {key}")
+ operator = "IS NULL" if value else "IS NOT NULL"
+ deflated_value = None
+ elif operator in _REGEX_OPERATOR_TABLE.values():
+ deflated_value = property_obj.deflate(value)
+ if not isinstance(deflated_value, str):
+ raise ValueError(f"Must be a string value for {key}")
+ if operator in _STRING_REGEX_OPERATOR_TABLE.values():
+ deflated_value = re.escape(deflated_value)
+ deflated_value = operator.format(deflated_value)
+ operator = _SPECIAL_OPERATOR_REGEX
+ else:
+ deflated_value = property_obj.deflate(value)
- return output
+ return deflated_value, operator, prop
-def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj):
- """
- Transform in operator to a cypher filter
- Args:
- operator (str): operator to transform
- filter_key (str): filter key
- filter_value (str): filter value
- property_obj (object): property object
- Returns:
- tuple: operator, deflated_value
- """
- if not isinstance(filter_value, tuple) and not isinstance(filter_value, list):
- raise ValueError(
- f"Value must be a tuple or list for IN operation {filter_key}={filter_value}"
- )
- if isinstance(property_obj, ArrayProperty):
- deflated_value = property_obj.deflate(filter_value)
- operator = _SPECIAL_OPERATOR_ARRAY_IN
+def _deflate_value(
+ cls, property_obj: Property, key: str, value: str, operator: str, prop: str
+) -> Tuple[str, str, str]:
+ if isinstance(property_obj, AliasProperty):
+ prop = property_obj.aliased_to()
+ deflated_value = getattr(cls, prop).deflate(value)
else:
- deflated_value = [property_obj.deflate(v) for v in filter_value]
+ # handle special operators
+ deflated_value, operator, prop = _handle_special_operators(
+ property_obj, key, value, operator, prop
+ )
- return operator, deflated_value
+ return deflated_value, operator, prop
+
+
+def _initialize_filter_args_variables(cls, key: str):
+ current_class = cls
+ current_rel_model = None
+ leaf_prop = None
+ operator = "="
+ is_rel_property = "|" in key
+ prop = key
+
+ return current_class, current_rel_model, leaf_prop, operator, is_rel_property, prop
+
+
+def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]:
+ (
+ current_class,
+ current_rel_model,
+ leaf_prop,
+ operator,
+ is_rel_property,
+ prop,
+ ) = _initialize_filter_args_variables(cls, key)
+
+ for part in re.split(path_split_regex, key):
+ defined_props = current_class.defined_properties(rels=True)
+ # update defined props dictionary with relationship properties if
+ # we are filtering by property
+ if is_rel_property and current_rel_model:
+ defined_props.update(current_rel_model.defined_properties(rels=True))
+ if part in defined_props:
+ if isinstance(
+ defined_props[part], relationship_manager.RelationshipDefinition
+ ):
+ defined_props[part].lookup_node_class()
+ current_class = defined_props[part].definition["node_class"]
+ current_rel_model = defined_props[part].definition["model"]
+ elif part in OPERATOR_TABLE:
+ operator = OPERATOR_TABLE[part]
+ prop, _ = prop.rsplit("__", 1)
+ continue
+ else:
+ raise ValueError(
+ f"No such property {part} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
+ )
+ leaf_prop = part
+ if is_rel_property and current_rel_model:
+ property_obj = getattr(current_rel_model, leaf_prop)
+ else:
+ property_obj = getattr(current_class, leaf_prop)
-def transform_null_operator_to_filter(filter_key, filter_value):
- """
- Transform null operator to a cypher filter
- Args:
- filter_key (str): filter key
- filter_value (str): filter value
- Returns:
- tuple: operator, deflated_value
- """
- if not isinstance(filter_value, bool):
- raise ValueError(f"Value must be a bool for isnull operation on {filter_key}")
- operator = "IS NULL" if filter_value else "IS NOT NULL"
- deflated_value = None
- return operator, deflated_value
+ return property_obj, operator, prop
-def transform_regex_operator_to_filter(
- operator, filter_key, filter_value, property_obj
-):
+def process_filter_args(cls, kwargs) -> Dict:
"""
- Transform regex operator to a cypher filter
- Args:
- operator (str): operator to transform
- filter_key (str): filter key
- filter_value (str): filter value
- property_obj (object): property object
- Returns:
- tuple: operator, deflated_value
+ loop through properties in filter parameters check they match class definition
+ deflate them and convert into something easy to generate cypher from
"""
+ output = {}
- deflated_value = property_obj.deflate(filter_value)
- if not isinstance(deflated_value, str):
- raise ValueError(f"Must be a string value for {filter_key}")
- if operator in _STRING_REGEX_OPERATOR_TABLE.values():
- deflated_value = re.escape(deflated_value)
- deflated_value = operator.format(deflated_value)
- operator = _SPECIAL_OPERATOR_REGEX
- return operator, deflated_value
-
-
-def transform_operator_to_filter(operator, filter_key, filter_value, property_obj):
- if operator == _SPECIAL_OPERATOR_IN:
- operator, deflated_value = transform_in_operator_to_filter(
- operator=operator,
- filter_key=filter_key,
- filter_value=filter_value,
- property_obj=property_obj,
- )
- elif operator == _SPECIAL_OPERATOR_ISNULL:
- operator, deflated_value = transform_null_operator_to_filter(
- filter_key=filter_key, filter_value=filter_value
- )
- elif operator in _REGEX_OPERATOR_TABLE.values():
- operator, deflated_value = transform_regex_operator_to_filter(
- operator=operator,
- filter_key=filter_key,
- filter_value=filter_value,
- property_obj=property_obj,
+ for key, value in kwargs.items():
+ property_obj, operator, prop = _process_filter_key(cls, key)
+ deflated_value, operator, prop = _deflate_value(
+ cls, property_obj, key, value, operator, prop
)
- else:
- deflated_value = property_obj.deflate(filter_value)
+ # map property to correct property name in the database
+ db_property = prop
- return operator, deflated_value
+ output[db_property] = (operator, deflated_value)
+ return output
def process_has_args(cls, kwargs):
@@ -374,34 +368,34 @@ def process_has_args(cls, kwargs):
class QueryAST:
- match: Optional[list]
- optional_match: Optional[list]
- where: Optional[list]
- with_clause: Optional[str]
- return_clause: Optional[str]
- order_by: Optional[str]
- skip: Optional[int]
- limit: Optional[int]
- result_class: Optional[type]
- lookup: Optional[str]
- additional_return: Optional[list]
- is_count: Optional[bool]
+ match: List[str]
+ optional_match: List[str]
+ where: List[str]
+ with_clause: TOptional[str]
+ return_clause: TOptional[str]
+ order_by: TOptional[List[str]]
+ skip: TOptional[int]
+ limit: TOptional[int]
+ result_class: TOptional[type]
+ lookup: TOptional[str]
+ additional_return: List[str]
+ is_count: TOptional[bool]
def __init__(
self,
- match: Optional[list] = None,
- optional_match: Optional[list] = None,
- where: Optional[list] = None,
- with_clause: Optional[str] = None,
- return_clause: Optional[str] = None,
- order_by: Optional[str] = None,
- skip: Optional[int] = None,
- limit: Optional[int] = None,
- result_class: Optional[type] = None,
- lookup: Optional[str] = None,
- additional_return: Optional[list] = None,
- is_count: Optional[bool] = False,
- ):
+ match: TOptional[List[str]] = None,
+ optional_match: TOptional[List[str]] = None,
+ where: TOptional[List[str]] = None,
+ with_clause: TOptional[str] = None,
+ return_clause: TOptional[str] = None,
+ order_by: TOptional[List[str]] = None,
+ skip: TOptional[int] = None,
+ limit: TOptional[int] = None,
+ result_class: TOptional[type] = None,
+ lookup: TOptional[str] = None,
+ additional_return: TOptional[List[str]] = None,
+ is_count: TOptional[bool] = False,
+ ) -> None:
self.match = match if match else []
self.optional_match = optional_match if optional_match else []
self.where = where if where else []
@@ -414,18 +408,19 @@ def __init__(
self.lookup = lookup
self.additional_return = additional_return if additional_return else []
self.is_count = is_count
+ self.subgraph: Dict = {}
class QueryBuilder:
- def __init__(self, node_set):
+ def __init__(self, node_set, subquery_context: bool = False) -> None:
self.node_set = node_set
self._ast = QueryAST()
- self._query_params = {}
- self._place_holder_registry = {}
- self._ident_count = 0
- self._node_counters = defaultdict(int)
+ self._query_params: Dict = {}
+ self._place_holder_registry: Dict = {}
+ self._ident_count: int = 0
+ self._subquery_context: bool = subquery_context
- def build_ast(self):
+ def build_ast(self) -> "QueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
for relation in self.node_set.relations_to_fetch:
self.build_traversal_from_path(relation, self.node_set.source)
@@ -439,7 +434,7 @@ def build_ast(self):
return self
- def build_source(self, source):
+ def build_source(self, source) -> str:
if isinstance(source, Traversal):
return self.build_traversal(source)
if isinstance(source, NodeSet):
@@ -468,18 +463,40 @@ def build_source(self, source):
return self.build_node(source)
raise ValueError("Unknown source type " + repr(source))
- def create_ident(self):
+ def create_ident(self) -> str:
self._ident_count += 1
- return "r" + str(self._ident_count)
+ return f"r{self._ident_count}"
- def build_order_by(self, ident, source):
+ def build_order_by(self, ident: str, source: "NodeSet") -> None:
if "?" in source.order_by_elements:
self._ast.with_clause = f"{ident}, rand() as r"
- self._ast.order_by = "r"
+ self._ast.order_by = ["r"]
else:
- self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements]
+ order_by = []
+ for elm in source.order_by_elements:
+ if isinstance(elm, RawCypher):
+ order_by.append(elm.render({"n": ident}))
+ continue
+ is_rel_property = "|" in elm
+ if "__" not in elm and not is_rel_property:
+ prop = elm.split(" ")[0] if " " in elm else elm
+ if prop not in source.source_class.defined_properties(rels=False):
+ raise ValueError(
+ f"No such property {prop} on {source.source_class.__name__}. "
+ f"Note that Neo4j internals like id or element_id are not allowed "
+ f"for use in this operation."
+ )
+ order_by.append(f"{ident}.{elm}")
+ else:
+ path, prop = elm.rsplit("__" if not is_rel_property else "|", 1)
+ result = self.lookup_query_variable(
+ path, return_relation=is_rel_property
+ )
+ if result:
+ order_by.append(f"{result[0]}.{prop}")
+ self._ast.order_by = order_by
- def build_traversal(self, traversal):
+ def build_traversal(self, traversal) -> str:
"""
traverse a relationship from a node to a set of nodes
"""
@@ -507,27 +524,29 @@ def build_traversal(self, traversal):
return traversal_ident
- def _additional_return(self, name):
+ def _additional_return(self, name: str):
if name not in self._ast.additional_return and name != self._ast.return_clause:
self._ast.additional_return.append(name)
- def build_traversal_from_path(self, relation: dict, source_class) -> str:
+ def build_traversal_from_path(
+ self, relation: dict, source_class
+ ) -> Tuple[str, Any]:
path: str = relation["path"]
stmt: str = ""
source_class_iterator = source_class
- for index, part in enumerate(path.split("__")):
+ parts = re.split(path_split_regex, path)
+ subgraph = self._ast.subgraph
+ rel_iterator: str = ""
+ already_present = False
+ existing_rhs_name = ""
+ for index, part in enumerate(parts):
relationship = getattr(source_class_iterator, part)
+ if rel_iterator:
+ rel_iterator += "__"
+ rel_iterator += part
# build source
if "node_class" not in relationship.definition:
relationship.lookup_node_class()
- rhs_label = relationship.definition["node_class"].__label__
- rel_reference = f'{relationship.definition["node_class"]}_{part}'
- self._node_counters[rel_reference] += 1
- rhs_name = (
- f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
- )
- rhs_ident = f"{rhs_name}:{rhs_label}"
- self._additional_return(rhs_name)
if not stmt:
lhs_label = source_class_iterator.__label__
lhs_name = lhs_label.lower()
@@ -537,13 +556,46 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
- else:
+ if self._subquery_context:
+ # Don't include label in identifier if we are in a subquery
+ lhs_ident = lhs_name
+ elif relation["include_in_return"]:
self._additional_return(lhs_name)
else:
lhs_ident = stmt
+ already_present = part in subgraph
rel_ident = self.create_ident()
- self._additional_return(rel_ident)
+ rhs_label = relationship.definition["node_class"].__label__
+ if relation.get("relation_filtering"):
+ rhs_name = rel_ident
+ else:
+ if index + 1 == len(parts) and "alias" in relation:
+ # If an alias is defined, use it to store the last hop in the path
+ rhs_name = relation["alias"]
+ else:
+ rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
+ rhs_ident = f"{rhs_name}:{rhs_label}"
+ if relation["include_in_return"] and not already_present:
+ self._additional_return(rhs_name)
+
+ if not already_present:
+ subgraph[part] = {
+ "target": relationship.definition["node_class"],
+ "children": {},
+ "variable_name": rhs_name,
+ "rel_variable_name": rel_ident,
+ }
+ else:
+ existing_rhs_name = subgraph[part][
+ (
+ "rel_variable_name"
+ if relation.get("relation_filtering")
+ else "variable_name"
+ )
+ ]
+ if relation["include_in_return"] and not already_present:
+ self._additional_return(rel_ident)
stmt = _rel_helper(
lhs=lhs_ident,
rhs=rhs_ident,
@@ -552,12 +604,16 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str:
relation_type=relationship.definition["relation_type"],
)
source_class_iterator = relationship.definition["node_class"]
+ subgraph = subgraph[part]["children"]
- if relation.get("optional"):
- self._ast.optional_match.append(stmt)
- else:
- self._ast.match.append(stmt)
- return rhs_name
+ if not already_present:
+ if relation.get("optional"):
+ self._ast.optional_match.append(stmt)
+ else:
+ self._ast.match.append(stmt)
+ return rhs_name, relationship.definition["node_class"]
+
+ return existing_rhs_name, relationship.definition["node_class"]
def build_node(self, node):
ident = node.__class__.__name__.lower()
@@ -573,7 +629,7 @@ def build_node(self, node):
self._ast.result_class = node.__class__
return ident
- def build_label(self, ident, cls):
+ def build_label(self, ident, cls) -> str:
"""
match nodes by a label
"""
@@ -609,14 +665,71 @@ def build_additional_match(self, ident, node_set):
else:
raise ValueError("Expecting dict got: " + repr(val))
- def _register_place_holder(self, key):
+ def _register_place_holder(self, key: str) -> str:
if key in self._place_holder_registry:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
return key + "_" + str(self._place_holder_registry[key])
- def _parse_q_filters(self, ident, q, source_class):
+ def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
+ is_rel_filter = "|" in prop
+ if is_rel_filter:
+ path, prop = prop.rsplit("|", 1)
+ else:
+ path, prop = prop.rsplit("__", 1)
+ result = self.lookup_query_variable(path, return_relation=is_rel_filter)
+ if not result:
+ ident, target_class = self.build_traversal_from_path(
+ {
+ "path": path,
+ "include_in_return": True,
+ "relation_filtering": is_rel_filter,
+ },
+ source_class,
+ )
+ else:
+ ident, target_class = result
+ return ident, path, prop, target_class
+
+ def _finalize_filter_statement(
+ self, operator: str, ident: str, prop: str, val: Any
+ ) -> str:
+ if operator in _UNARY_OPERATORS:
+ # unary operators do not have a parameter
+ statement = f"{ident}.{prop} {operator}"
+ else:
+ place_holder = self._register_place_holder(ident + "_" + prop)
+ if operator == _SPECIAL_OPERATOR_ARRAY_IN:
+ statement = operator.format(
+ ident=ident,
+ prop=prop,
+ val=f"${place_holder}",
+ )
+ else:
+ statement = f"{ident}.{prop} {operator} ${place_holder}"
+ self._query_params[place_holder] = val
+
+ return statement
+
+ def _build_filter_statements(
+ self, ident: str, filters, target: List[str], source_class
+ ) -> None:
+ for prop, op_and_val in filters.items():
+ path = None
+ is_rel_filter = "|" in prop
+ target_class = source_class
+ if "__" in prop or is_rel_filter:
+ ident, path, prop, target_class = self._parse_path(source_class, prop)
+ operator, val = op_and_val
+ if not is_rel_filter:
+ prop = target_class.defined_properties(rels=False)[
+ prop
+ ].get_db_property_name(prop)
+ statement = self._finalize_filter_statement(operator, ident, prop, val)
+ target.append(statement)
+
+ def _parse_q_filters(self, ident, q, source_class) -> str:
target = []
for child in q.children:
if isinstance(child, QBase):
@@ -627,36 +740,22 @@ def _parse_q_filters(self, ident, q, source_class):
else:
kwargs = {child[0]: child[1]}
filters = process_filter_args(source_class, kwargs)
- for prop, op_and_val in filters.items():
- operator, val = op_and_val
- if operator in _UNARY_OPERATORS:
- # unary operators do not have a parameter
- statement = f"{ident}.{prop} {operator}"
- else:
- place_holder = self._register_place_holder(ident + "_" + prop)
- if operator == _SPECIAL_OPERATOR_ARRAY_IN:
- statement = operator.format(
- ident=ident,
- prop=prop,
- val=f"${place_holder}",
- )
- else:
- statement = f"{ident}.{prop} {operator} ${place_holder}"
- self._query_params[place_holder] = val
- target.append(statement)
+ self._build_filter_statements(ident, filters, target, source_class)
ret = f" {q.connector} ".join(target)
if q.negated:
ret = f"NOT ({ret})"
return ret
- def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
+ def build_where_stmt(
+ self, ident: str, filters, q_filters=None, source_class=None
+ ) -> None:
"""
construct a where statement from some filters
"""
if q_filters is not None:
- stmts = self._parse_q_filters(ident, q_filters, source_class)
- if stmts:
- self._ast.where.append(stmts)
+ stmt = self._parse_q_filters(ident, q_filters, source_class)
+ if stmt:
+ self._ast.where.append(stmt)
else:
stmts = []
for row in filters:
@@ -682,8 +781,37 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
self._ast.where.append(" AND ".join(stmts))
- def build_query(self):
- query = ""
+ def lookup_query_variable(
+ self, path: str, return_relation: bool = False
+ ) -> TOptional[Tuple[str, Any]]:
+ """Retrieve the variable name generated internally for the given traversal path."""
+ subgraph = self._ast.subgraph
+ if not subgraph:
+ return None
+ traversals = re.split(path_split_regex, path)
+ if len(traversals) == 0:
+ raise ValueError("Can only lookup traversal variables")
+ if traversals[0] not in subgraph:
+ return None
+ subgraph = subgraph[traversals[0]]
+ if len(traversals) == 1:
+ variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}"
+ return variable_to_return, subgraph["target"]
+ variable_to_return = ""
+ last_property = traversals[-1]
+ for part in traversals[1:]:
+ child = subgraph["children"].get(part)
+ if not child:
+ return None
+ subgraph = child
+ if part == last_property:
+ # if last part of prop is the last traversal
+ # we are safe to lookup the variable from the query
+ variable_to_return = f"{subgraph['rel_variable_name' if return_relation else 'variable_name']}"
+ return variable_to_return, subgraph["target"]
+
+ def build_query(self) -> str:
+ query: str = ""
if self._ast.lookup:
query += self._ast.lookup
@@ -702,6 +830,9 @@ def build_query(self):
query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match)
if self._ast.where:
+ if self._ast.optional_match:
+ # Make sure filtering works as expected with optional match, even if it's not performant...
+ query += " WITH *"
query += " WHERE "
query += " AND ".join(self._ast.where)
@@ -709,13 +840,85 @@ def build_query(self):
query += " WITH "
query += self._ast.with_clause
+ if hasattr(self.node_set, "_intermediate_transforms"):
+ for transform in self.node_set._intermediate_transforms:
+ query += " WITH "
+ injected_vars: list = []
+ # Reset return list since we'll probably invalidate most variables
+ self._ast.return_clause = ""
+ self._ast.additional_return = []
+ for name, source in transform["vars"].items():
+ if type(source) is str:
+ injected_vars.append(f"{source} AS {name}")
+ elif isinstance(source, RelationNameResolver):
+ result = self.lookup_query_variable(
+ source.relation, return_relation=True
+ )
+ if not result:
+ raise ValueError(
+ f"Unable to resolve variable name for relation {source.relation}."
+ )
+ injected_vars.append(f"{result[0]} AS {name}")
+ elif isinstance(source, NodeNameResolver):
+ result = self.lookup_query_variable(source.node)
+ if not result:
+ raise ValueError(
+ f"Unable to resolve variable name for node {source.node}."
+ )
+ injected_vars.append(f"{result[0]} AS {name}")
+ query += ",".join(injected_vars)
+ if not transform["ordering"]:
+ continue
+ query += " ORDER BY "
+ ordering: list = []
+ for item in transform["ordering"]:
+ if isinstance(item, RawCypher):
+ ordering.append(item.render({}))
+ continue
+ if item.startswith("-"):
+ ordering.append(f"{item[1:]} DESC")
+ else:
+ ordering.append(item)
+ query += ",".join(ordering)
+
+ returned_items: list[str] = []
+ if hasattr(self.node_set, "_subqueries"):
+ for subquery, return_set in self.node_set._subqueries:
+ outer_primary_var = self._ast.return_clause
+ query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
+ for varname in return_set:
+ # We declare the returned variables as "virtual" relations of the
+ # root node class to make sure they will be translated by a call to
+ # resolve_subgraph() (otherwise, they will be lost).
+ # This is probably a temporary solution until we find something better...
+ self._ast.subgraph[varname] = {
+ "target": None, # We don't need target class in this use case
+ "children": {},
+ "variable_name": varname,
+ "rel_variable_name": varname,
+ }
+ returned_items += return_set
+
query += " RETURN "
- if self._ast.return_clause:
- query += self._ast.return_clause
+ if self._ast.return_clause and not self._subquery_context:
+ returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
- if self._ast.return_clause:
- query += ", "
- query += ", ".join(self._ast.additional_return)
+ returned_items += self._ast.additional_return
+ if hasattr(self.node_set, "_extra_results"):
+ for props in self.node_set._extra_results:
+ leftpart = props["vardef"].render(self)
+ varname = (
+ props["alias"]
+ if props.get("alias")
+ else props["vardef"].get_internal_name()
+ )
+ if varname in returned_items:
+ # We're about to override an existing variable, delete it first to
+ # avoid duplicate error
+ returned_items.remove(varname)
+ returned_items.append(f"{leftpart} AS {varname}")
+
+ query += ", ".join(returned_items)
if self._ast.order_by:
query += " ORDER BY "
@@ -754,7 +957,6 @@ def _count(self):
def _contains(self, node_element_id):
# inject id = into ast
if not self._ast.return_clause:
- print(self._ast.additional_return)
self._ast.return_clause = self._ast.additional_return[0]
ident = self._ast.return_clause
place_holder = self._register_place_holder(ident + "_contains")
@@ -762,7 +964,7 @@ def _contains(self, node_element_id):
self._query_params[place_holder] = node_element_id
return self._count() >= 1
- def _execute(self, lazy=False):
+ def _execute(self, lazy: bool = False, dict_output: bool = False):
if lazy:
# inject id() into return or return_set
if self._ast.return_clause:
@@ -775,7 +977,13 @@ def _execute(self, lazy=False):
for item in self._ast.additional_return
]
query = self.build_query()
- results, _ = db.cypher_query(query, self._query_params, resolve_objects=True)
+ results, prop_names = db.cypher_query(
+ query, self._query_params, resolve_objects=True
+ )
+ if dict_output:
+ for item in results:
+ yield dict(zip(prop_names, item))
+ return
# The following is not as elegant as it could be but had to be copied from the
# version prior to cypher_query with the resolve_objects capability.
# It seems that certain calls are only supposed to be focusing to the first
@@ -796,6 +1004,7 @@ class BaseSet:
"""
query_cls = QueryBuilder
+ source_class: StructuredNode
def all(self, lazy=False):
"""
@@ -819,7 +1028,7 @@ def __len__(self):
ast = self.query_cls(self).build_ast()
return ast._count()
- def __bool__(self):
+ def __bool__(self) -> bool:
"""
Override for __bool__ dunder method.
:return: True if the set contains any nodes, False otherwise
@@ -829,7 +1038,7 @@ def __bool__(self):
_count = ast._count()
return _count > 0
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
"""
Override for __bool__ dunder method.
:return: True if the set contains any node, False otherwise
@@ -877,12 +1086,126 @@ class Optional:
relation: str
+@dataclass
+class RelationNameResolver:
+ """Helper to refer to a relation variable name.
+
+ Since variable names are generated automatically within MATCH statements (for
+ anything injected using fetch_relations or traverse_relations), we need a way to
+ retrieve them.
+
+ """
+
+ relation: str
+
+
+@dataclass
+class NodeNameResolver:
+ """Helper to refer to a node variable name.
+
+ Since variable names are generated automatically within MATCH statements (for
+ anything injected using fetch_relations or traverse_relations), we need a way to
+ retrieve them.
+
+ """
+
+ node: str
+
+
+@dataclass
+class BaseFunction:
+ input_name: Union[str, "BaseFunction", NodeNameResolver, RelationNameResolver]
+
+ def __post_init__(self) -> None:
+ self._internal_name: str = ""
+
+ def get_internal_name(self) -> str:
+ return self._internal_name
+
+ def resolve_internal_name(self, qbuilder: QueryBuilder) -> str:
+ if isinstance(self.input_name, NodeNameResolver):
+ result = qbuilder.lookup_query_variable(self.input_name.node)
+ elif isinstance(self.input_name, RelationNameResolver):
+ result = qbuilder.lookup_query_variable(self.input_name.relation, True)
+ else:
+ result = (str(self.input_name), None)
+ if result is None:
+ raise ValueError(f"Unknown variable {self.input_name} used in Collect()")
+ self._internal_name = result[0]
+ return self._internal_name
+
+ def render(self, qbuilder: QueryBuilder) -> str:
+ raise NotImplementedError
+
+
+@dataclass
+class AggregatingFunction(BaseFunction):
+ """Base aggregating function class."""
+
+ pass
+
+
+@dataclass
+class Collect(AggregatingFunction):
+ """collect() function."""
+
+ distinct: bool = False
+
+ def render(self, qbuilder: QueryBuilder) -> str:
+ varname = self.resolve_internal_name(qbuilder)
+ if self.distinct:
+ return f"collect(DISTINCT {varname})"
+ return f"collect({varname})"
+
+
+@dataclass
+class ScalarFunction(BaseFunction):
+ """Base scalar function class."""
+
+ pass
+
+
+@dataclass
+class Last(ScalarFunction):
+ """last() function."""
+
+ def render(self, qbuilder: QueryBuilder) -> str:
+ if isinstance(self.input_name, str):
+ content = str(self.input_name)
+ elif isinstance(self.input_name, BaseFunction):
+ content = self.input_name.render(qbuilder)
+ self._internal_name = self.input_name.get_internal_name()
+ else:
+ content = self.resolve_internal_name(qbuilder)
+ return f"last({content})"
+
+
+@dataclass
+class RawCypher:
+ """Helper to inject raw cypher statement.
+
+ Can be used in order_by() call for example.
+
+ """
+
+ statement: str
+
+ def __post_init__(self):
+ if CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR.search(self.statement):
+ raise ValueError(
+ "RawCypher: Do not include any action that has side effect"
+ )
+
+ def render(self, context: Dict) -> str:
+ return string.Template(self.statement).substitute(context)
+
+
class NodeSet(BaseSet):
"""
A class representing as set of nodes matching common query parameters
"""
- def __init__(self, source):
+ def __init__(self, source) -> None:
self.source = source # could be a Traverse object or a node class
if isinstance(source, Traversal):
self.source_class = source.target_class
@@ -896,14 +1219,18 @@ def __init__(self, source):
# setup Traversal objects using relationship definitions
install_traversals(self.source_class, self)
- self.filters = []
+ self.filters: List = []
self.q_filters = Q()
+ self.order_by_elements: List = []
# used by has()
- self.must_match = {}
- self.dont_match = {}
+ self.must_match: Dict = {}
+ self.dont_match: Dict = {}
- self.relations_to_fetch: list = []
+ self.relations_to_fetch: List = []
+ self._extra_results: List = []
+ self._subqueries: list[Tuple[str, list[str]]] = []
+ self._intermediate_transforms: list = []
def __await__(self):
return self.all().__await__()
@@ -968,7 +1295,7 @@ def first_or_none(self, **kwargs):
pass
return None
- def filter(self, *args, **kwargs):
+ def filter(self, *args, **kwargs) -> "BaseSet":
"""
Apply filters to the existing nodes in the set.
@@ -1038,6 +1365,9 @@ def order_by(self, *props):
self.order_by_elements.append("?")
else:
for prop in props:
+ if isinstance(prop, RawCypher):
+ self.order_by_elements.append(prop)
+ continue
prop = prop.strip()
if prop.startswith("-"):
prop = prop[1:]
@@ -1045,31 +1375,183 @@ def order_by(self, *props):
else:
desc = False
- if prop not in self.source_class.defined_properties(rels=False):
- raise ValueError(
- f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
- )
-
- property_obj = getattr(self.source_class, prop)
- if isinstance(property_obj, AliasProperty):
- prop = property_obj.aliased_to()
+ if prop in self.source_class.defined_properties(rels=False):
+ property_obj = getattr(self.source_class, prop)
+ if isinstance(property_obj, AliasProperty):
+ prop = property_obj.aliased_to()
self.order_by_elements.append(prop + (" DESC" if desc else ""))
return self
+ def _register_relation_to_fetch(
+ self,
+ relation_def: Any,
+ alias: TOptional[str] = None,
+ include_in_return: bool = True,
+ ):
+ if isinstance(relation_def, Optional):
+ item = {"path": relation_def.relation, "optional": True}
+ else:
+ item = {"path": relation_def}
+ item["include_in_return"] = include_in_return
+ if alias:
+ item["alias"] = alias
+ return item
+
def fetch_relations(self, *relation_names):
- """Specify a set of relations to return."""
+ """Specify a set of relations to traverse and return."""
relations = []
for relation_name in relation_names:
- if isinstance(relation_name, Optional):
- item = {"path": relation_name.relation, "optional": True}
- else:
- item = {"path": relation_name}
- relations.append(item)
+ relations.append(self._register_relation_to_fetch(relation_name))
self.relations_to_fetch = relations
return self
+ def traverse_relations(self, *relation_names, **aliased_relation_names):
+ """Specify a set of relations to traverse only."""
+ relations = []
+ for relation_name in relation_names:
+ relations.append(
+ self._register_relation_to_fetch(relation_name, include_in_return=False)
+ )
+ for alias, relation_def in aliased_relation_names.items():
+ relations.append(
+ self._register_relation_to_fetch(
+ relation_def, alias, include_in_return=False
+ )
+ )
+
+ self.relations_to_fetch = relations
+ return self
+
+ def annotate(self, *vars, **aliased_vars):
+ """Annotate node set results with extra variables."""
+
+ def register_extra_var(vardef, varname: Union[str, None] = None):
+ if isinstance(vardef, (AggregatingFunction, ScalarFunction)):
+ self._extra_results.append(
+ {"vardef": vardef, "alias": varname if varname else ""}
+ )
+ else:
+ raise NotImplementedError
+
+ for vardef in vars:
+ register_extra_var(vardef)
+ for varname, vardef in aliased_vars.items():
+ register_extra_var(vardef, varname)
+
+ return self
+
+ def _to_subgraph(self, root_node, other_nodes, subgraph):
+ """Recursive method to build root_node's relation graph from subgraph."""
+ root_node._relations = {}
+ for name, relation_def in subgraph.items():
+ for var_name, node in other_nodes.items():
+ if (
+ var_name
+ not in [
+ relation_def["variable_name"],
+ relation_def["rel_variable_name"],
+ ]
+ or node is None
+ ):
+ continue
+ if isinstance(node, list):
+ if len(node) > 0 and isinstance(node[0], StructuredRel):
+ name += "_relationship"
+ root_node._relations[name] = []
+ for item in node:
+ root_node._relations[name].append(
+ self._to_subgraph(
+ item, other_nodes, relation_def["children"]
+ )
+ )
+ else:
+ if isinstance(node, StructuredRel):
+ name += "_relationship"
+ root_node._relations[name] = self._to_subgraph(
+ node, other_nodes, relation_def["children"]
+ )
+
+ return root_node
+
+ def resolve_subgraph(self) -> list:
+ """
+ Convert every result contained in this node set to a subgraph.
+
+ By default, we receive results from neomodel as a list of
+ nodes without the hierarchy. This method tries to rebuild this
+ hierarchy without overriding anything in the node, that's why
+ we use a dedicated property to store node's relations.
+
+ """
+ if (
+ self.relations_to_fetch
+ and not self.relations_to_fetch[0]["include_in_return"]
+ ):
+ raise NotImplementedError(
+ "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
+ )
+ results: list = []
+ qbuilder = self.query_cls(self)
+ qbuilder.build_ast()
+ if not qbuilder._ast.subgraph:
+ raise RuntimeError(
+ "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
+ )
+ other_nodes = {}
+ root_node = None
+ for row in qbuilder._execute(dict_output=True):
+ for name, node in row.items():
+ if node.__class__ is self.source and "_" not in name:
+ root_node = node
+ continue
+ if isinstance(node, list) and isinstance(node[0], list):
+ other_nodes[name] = node[0]
+ continue
+ other_nodes[name] = node
+ results.append(
+ self._to_subgraph(root_node, other_nodes, qbuilder._ast.subgraph)
+ )
+ return results
+
+ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet":
+ """Add a subquery to this node set.
+
+ A subquery is a regular Cypher query but executed within the context of a CALL
+ statement. Such a query will generally fetch additional variables, which must
+ be declared inside the return_set variable in order to be included in the
+ final RETURN statement.
+ """
+ qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast()
+ for var in return_set:
+ if (
+ var != qbuilder._ast.return_clause
+ and var not in qbuilder._ast.additional_return
+ and var
+ not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
+ ):
+ raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
+ self._subqueries.append((qbuilder.build_query(), return_set))
+ return self
+
+ def intermediate_transform(
+ self, vars: Dict[str, Any], ordering: TOptional[list] = None
+ ) -> "NodeSet":
+ if not vars:
+ raise ValueError(
+ "You must provide one variable at least when calling intermediate_transform()"
+ )
+ for name, source in vars.items():
+ if type(source) is not str and not isinstance(
+ source, (NodeNameResolver, RelationNameResolver)
+ ):
+ raise ValueError(
+ f"Wrong source type specified for variable '{name}', should be a string or an instance of NodeNameResolver or RelationNameResolver"
+ )
+ self._intermediate_transforms.append({"vars": vars, "ordering": ordering})
+ return self
+
class Traversal(BaseSet):
"""
@@ -1083,13 +1565,13 @@ class Traversal(BaseSet):
:type name: :class:`str`
:param definition: A relationship definition that most certainly deserves
a documentation here.
- :type defintion: :class:`dict`
+ :type definition: :class:`dict`
"""
def __await__(self):
return self.all().__await__()
- def __init__(self, source, name, definition):
+ def __init__(self, source, name, definition) -> None:
"""
Create a traversal
@@ -1119,7 +1601,7 @@ def __init__(self, source, name, definition):
self.definition = definition
self.target_class = definition["node_class"]
self.name = name
- self.filters = []
+ self.filters: List = []
def match(self, **kwargs):
"""
diff --git a/pyproject.toml b/pyproject.toml
index 3ad12f86..d72c546b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,9 +22,9 @@ classifiers = [
"Topic :: Database",
]
dependencies = [
- "neo4j~=5.19.0",
+ "neo4j~=5.26.0",
]
-requires-python = ">=3.7"
+requires-python = ">=3.8"
dynamic = ["version"]
[project.urls]
@@ -59,6 +59,7 @@ where = ["./"]
[tool.pytest.ini_options]
addopts = "--resetdb"
testpaths = "test"
+asyncio_default_fixture_loop_scope = "session"
[tool.isort]
profile = 'black'
diff --git a/requirements-dev.txt b/requirements-dev.txt
index bf3fa116..446dd8c1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,6 +3,7 @@
unasync>=0.5.0
pytest>=7.1
+pytest-asyncio>=0.19.0
pytest-cov>=4.0
pre-commit
black
diff --git a/requirements.txt b/requirements.txt
index ffbfe285..e7a3f522 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-neo4j~=5.19.0
+neo4j~=5.26.0
diff --git a/test/_async_compat/__init__.py b/test/_async_compat/__init__.py
index 342678c3..5bdc28e3 100644
--- a/test/_async_compat/__init__.py
+++ b/test/_async_compat/__init__.py
@@ -1,8 +1,10 @@
from .mark_decorator import (
AsyncTestDecorators,
TestDecorators,
+ mark_async_function_auto_fixture,
mark_async_session_auto_fixture,
mark_async_test,
+ mark_sync_function_auto_fixture,
mark_sync_session_auto_fixture,
mark_sync_test,
)
@@ -13,5 +15,7 @@
"mark_sync_test",
"TestDecorators",
"mark_async_session_auto_fixture",
+ "mark_async_function_auto_fixture",
"mark_sync_session_auto_fixture",
+ "mark_sync_function_auto_fixture",
]
diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py
index a8c5eead..5d6050d8 100644
--- a/test/_async_compat/mark_decorator.py
+++ b/test/_async_compat/mark_decorator.py
@@ -1,9 +1,15 @@
import pytest
import pytest_asyncio
-mark_async_test = pytest.mark.asyncio
-mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True)
+mark_async_test = pytest.mark.asyncio(loop_scope="session")
+mark_async_session_auto_fixture = pytest_asyncio.fixture(
+ loop_scope="session", scope="session", autouse=True
+)
+mark_async_function_auto_fixture = pytest_asyncio.fixture(
+ loop_scope="session", autouse=True
+)
mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True)
+mark_sync_function_auto_fixture = pytest.fixture(autouse=True)
def mark_sync_test(f):
diff --git a/test/async_/conftest.py b/test/async_/conftest.py
index 493ff12c..8cbf952b 100644
--- a/test/async_/conftest.py
+++ b/test/async_/conftest.py
@@ -1,15 +1,15 @@
-import asyncio
import os
import warnings
-from test._async_compat import mark_async_session_auto_fixture
-
-import pytest
+from test._async_compat import (
+ mark_async_function_auto_fixture,
+ mark_async_session_auto_fixture,
+)
from neomodel import adb, config
@mark_async_session_auto_fixture
-async def setup_neo4j_session(request, event_loop):
+async def setup_neo4j_session(request):
"""
Provides initial connection to the database and sets up the rest of the test suite
@@ -44,17 +44,12 @@ async def setup_neo4j_session(request, event_loop):
await adb.cypher_query("GRANT ROLE publisher TO troygreene")
await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin")
-
-@mark_async_session_auto_fixture
-async def cleanup(event_loop):
yield
+
await adb.close_connection()
-@pytest.fixture(scope="session")
-def event_loop():
- """Overrides pytest default function scoped event loop"""
- policy = asyncio.get_event_loop_policy()
- loop = policy.new_event_loop()
- yield loop
- loop.close()
+@mark_async_function_auto_fixture
+async def setUp():
+ await adb.cypher_query("MATCH (n) DETACH DELETE n")
+ yield
diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py
index 8106a796..ddbd6808 100644
--- a/test/async_/test_issue283.py
+++ b/test/async_/test_issue283.py
@@ -122,10 +122,6 @@ async def test_automatic_result_resolution():
# TechnicalPerson (!NOT basePerson!)
assert type((await A.friends_with)[0]) is TechnicalPerson
- await A.delete()
- await B.delete()
- await C.delete()
-
@mark_async_test
async def test_recursive_automatic_result_resolution():
@@ -176,11 +172,6 @@ async def test_recursive_automatic_result_resolution():
# Assert that primitive data types remain primitive data types
assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring)
- await A.delete()
- await B.delete()
- await C.delete()
- await D.delete()
-
@mark_async_test
async def test_validation_with_inheritance_from_db():
@@ -240,12 +231,6 @@ async def test_validation_with_inheritance_from_db():
)
assert type((await D.friends_with)[0]) is PilotPerson
- await A.delete()
- await B.delete()
- await C.delete()
- await D.delete()
- await E.delete()
-
@mark_async_test
async def test_validation_enforcement_to_db():
@@ -295,13 +280,6 @@ async def test_validation_enforcement_to_db():
with pytest.raises(ValueError):
await A.friends_with.connect(F)
- await A.delete()
- await B.delete()
- await C.delete()
- await D.delete()
- await E.delete()
- await F.delete()
-
@mark_async_test
async def test_failed_result_resolution():
@@ -344,9 +322,6 @@ class RandomPerson(BasePerson):
for some_friend in friends:
print(some_friend.name)
- await A.delete()
- await B.delete()
-
@mark_async_test
async def test_node_label_mismatch():
@@ -509,6 +484,10 @@ async def test_resolve_inexistent_relationship():
Attempting to resolve an inexistent relationship should raise an exception
:return:
"""
+ A = await TechnicalPerson(name="Michael Knight", expertise="Cars").save()
+ B = await TechnicalPerson(name="Luke Duke", expertise="Lasers").save()
+
+ await A.friends_with.connect(B)
# Forget about the FRIENDS_WITH Relationship.
del adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])]
@@ -518,7 +497,7 @@ async def test_resolve_inexistent_relationship():
match=r"Relationship of type .* does not resolve to any of the known objects.*",
):
query_data = await adb.cypher_query(
- "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) "
+ "MATCH (:TechnicalPerson)-[r:FRIENDS_WITH]->(:TechnicalPerson) "
"RETURN DISTINCT r",
resolve_objects=True,
)
diff --git a/test/async_/test_issue600.py b/test/async_/test_issue600.py
index 5f66f39e..3cf4e870 100644
--- a/test/async_/test_issue600.py
+++ b/test/async_/test_issue600.py
@@ -63,11 +63,6 @@ async def test_relationship_definer_second_sibling():
await B.rel_2.connect(C)
await C.rel_3.connect(A)
- # Clean up
- await A.delete()
- await B.delete()
- await C.delete()
-
@mark_async_test
async def test_relationship_definer_parent_last():
@@ -80,8 +75,3 @@ async def test_relationship_definer_parent_last():
await A.rel_1.connect(B)
await B.rel_2.connect(C)
await C.rel_3.connect(A)
-
- # Clean up
- await A.delete()
- await B.delete()
- await C.delete()
diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py
index 4f549dd5..39e96957 100644
--- a/test/async_/test_match_api.py
+++ b/test/async_/test_match_api.py
@@ -1,3 +1,4 @@
+import re
from datetime import datetime
from test._async_compat import mark_async_test
@@ -10,18 +11,25 @@
AsyncRelationshipTo,
AsyncStructuredNode,
AsyncStructuredRel,
+ AsyncZeroOrOne,
DateTimeProperty,
IntegerProperty,
Q,
StringProperty,
UniqueIdProperty,
+ adb,
)
from neomodel._async_compat.util import AsyncUtil
from neomodel.async_.match import (
AsyncNodeSet,
AsyncQueryBuilder,
AsyncTraversal,
+ Collect,
+ Last,
+ NodeNameResolver,
Optional,
+ RawCypher,
+ RelationNameResolver,
)
from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined
@@ -34,7 +42,7 @@ class SupplierRel(AsyncStructuredRel):
class Supplier(AsyncStructuredNode):
name = StringProperty()
delivery_cost = IntegerProperty()
- coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS")
+ coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS", model=SupplierRel)
class Species(AsyncStructuredNode):
@@ -79,6 +87,40 @@ class PersonX(AsyncStructuredNode):
city = AsyncRelationshipTo(CityX, "LIVES_IN")
+class SoftwareDependency(AsyncStructuredNode):
+ name = StringProperty(required=True)
+ version = StringProperty(required=True)
+
+
+class HasCourseRel(AsyncStructuredRel):
+ level = StringProperty()
+ start_date = DateTimeProperty()
+ end_date = DateTimeProperty()
+
+
+class Course(AsyncStructuredNode):
+ name = StringProperty()
+
+
+class Building(AsyncStructuredNode):
+ name = StringProperty()
+
+
+class Student(AsyncStructuredNode):
+ name = StringProperty()
+
+ parents = AsyncRelationshipTo("Student", "HAS_PARENT", model=AsyncStructuredRel)
+ children = AsyncRelationshipFrom("Student", "HAS_PARENT", model=AsyncStructuredRel)
+ lives_in = AsyncRelationshipTo(Building, "LIVES_IN", model=AsyncStructuredRel)
+ courses = AsyncRelationshipTo(Course, "HAS_COURSE", model=HasCourseRel)
+ preferred_course = AsyncRelationshipTo(
+ Course,
+ "HAS_PREFERRED_COURSE",
+ model=AsyncStructuredRel,
+ cardinality=AsyncZeroOrOne,
+ )
+
+
@mark_async_test
async def test_filter_exclude_via_labels():
await Coffee(name="Java", price=99).save()
@@ -144,7 +186,7 @@ async def test_get():
@mark_async_test
async def test_simple_traverse_with_filter():
nescafe = await Coffee(name="Nescafe2", price=99).save()
- tesco = await Supplier(name="Sainsburys", delivery_cost=2).save()
+ tesco = await Supplier(name="Tesco", delivery_cost=2).save()
await nescafe.suppliers.connect(tesco)
qb = AsyncQueryBuilder(
@@ -158,7 +200,7 @@ async def test_simple_traverse_with_filter():
assert qb._ast.match
assert qb._ast.return_clause.startswith("suppliers")
assert len(results) == 1
- assert results[0].name == "Sainsburys"
+ assert results[0].name == "Tesco"
@mark_async_test
@@ -211,9 +253,6 @@ async def test_len_and_iter_and_bool():
@mark_async_test
async def test_slice():
- for c in await Coffee.nodes:
- await c.delete()
-
await Coffee(name="Icelands finest").save()
await Coffee(name="Britains finest").save()
await Coffee(name="Japans finest").save()
@@ -282,9 +321,6 @@ async def test_contains():
@mark_async_test
async def test_order_by():
- for c in await Coffee.nodes:
- await c.delete()
-
c1 = await Coffee(name="Icelands finest", price=5).save()
c2 = await Coffee(name="Britains finest", price=10).save()
c3 = await Coffee(name="Japans finest", price=35).save()
@@ -305,13 +341,13 @@ async def test_order_by():
ns = ns.order_by("?")
qb = await AsyncQueryBuilder(ns).build_ast()
assert qb._ast.with_clause == "coffee, rand() as r"
- assert qb._ast.order_by == "r"
+ assert qb._ast.order_by == ["r"]
with raises(
ValueError,
match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.",
):
- await Coffee.nodes.order_by("id")
+ await Coffee.nodes.order_by("id").all()
# Test order by on a relationship
l = await Supplier(name="lidl2").save()
@@ -326,10 +362,27 @@ async def test_order_by():
@mark_async_test
-async def test_extra_filters():
- for c in await Coffee.nodes:
- await c.delete()
+async def test_order_by_rawcypher():
+ d1 = await SoftwareDependency(name="Package1", version="1.0.0").save()
+ d2 = await SoftwareDependency(name="Package2", version="1.4.0").save()
+ d3 = await SoftwareDependency(name="Package3", version="2.5.5").save()
+ assert (
+ await SoftwareDependency.nodes.order_by(
+ RawCypher("toInteger(split($n.version, '.')[0]) DESC"),
+ ).all()
+ )[0] == d3
+
+ with raises(
+ ValueError, match=r"RawCypher: Do not include any action that has side effect"
+ ):
+ SoftwareDependency.nodes.order_by(
+ RawCypher("DETACH DELETE $n"),
+ )
+
+
+@mark_async_test
+async def test_extra_filters():
c1 = await Coffee(name="Icelands finest", price=5, id_=1).save()
c2 = await Coffee(name="Britains finest", price=10, id_=2).save()
c3 = await Coffee(name="Japans finest", price=35, id_=3).save()
@@ -401,10 +454,6 @@ async def test_empty_filters():
``get_queryset`` function in ``GenericAPIView`` should returns
``NodeSet`` object.
"""
-
- for c in await Coffee.nodes:
- await c.delete()
-
c1 = await Coffee(name="Super", price=5, id_=1).save()
c2 = await Coffee(name="Puper", price=10, id_=2).save()
@@ -428,10 +477,6 @@ async def test_empty_filters():
@mark_async_test
async def test_q_filters():
- # Test where no children and self.connector != conn ?
- for c in await Coffee.nodes:
- await c.delete()
-
c1 = await Coffee(name="Icelands finest", price=5, id_=1).save()
c2 = await Coffee(name="Britains finest", price=10, id_=2).save()
c3 = await Coffee(name="Japans finest", price=35, id_=3).save()
@@ -522,7 +567,7 @@ async def test_traversal_filter_left_hand_statement():
nescafe = await Coffee(name="Nescafe2", price=99).save()
nescafe_gold = await Coffee(name="Nescafe gold", price=11).save()
- tesco = await Supplier(name="Sainsburys", delivery_cost=3).save()
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
biedronka = await Supplier(name="Biedronka", delivery_cost=5).save()
lidl = await Supplier(name="Lidl", delivery_cost=3).save()
@@ -539,20 +584,102 @@ async def test_traversal_filter_left_hand_statement():
assert lidl in lidl_supplier
+@mark_async_test
+async def test_filter_with_traversal():
+ arabica = await Species(name="Arabica").save()
+ robusta = await Species(name="Robusta").save()
+ nescafe = await Coffee(name="Nescafe", price=11).save()
+ nescafe_gold = await Coffee(name="Nescafe Gold", price=99).save()
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(tesco)
+ await nescafe.species.connect(arabica)
+ await nescafe_gold.species.connect(robusta)
+
+ results = await Coffee.nodes.filter(species__name="Arabica").all()
+ assert len(results) == 1
+ assert len(results[0]) == 3
+ assert results[0][0] == nescafe
+
+ results_multi_hop = await Supplier.nodes.filter(
+ coffees__species__name="Arabica"
+ ).all()
+ assert len(results_multi_hop) == 1
+ assert results_multi_hop[0][0] == tesco
+
+ no_results = await Supplier.nodes.filter(coffees__species__name="Noffee").all()
+ assert no_results == []
+
+
+@mark_async_test
+async def test_relation_prop_filtering():
+ arabica = await Species(name="Arabica").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
+ await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
+ await nescafe.species.connect(arabica)
+
+ results = await Supplier.nodes.filter(
+ **{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
+ ).all()
+
+ assert len(results) == 1
+ assert results[0][0] == supplier1
+
+ # Test it works with mixed argument syntaxes
+ results2 = await Supplier.nodes.filter(
+ name="Supplier 1",
+ coffees__name="Nescafe",
+ **{"coffees|since__gt": datetime(2018, 4, 1, 0, 0)},
+ ).all()
+
+ assert len(results2) == 1
+ assert results2[0][0] == supplier1
+
+
+@mark_async_test
+async def test_relation_prop_ordering():
+ arabica = await Species(name="Arabica").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
+ await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
+ await nescafe.species.connect(arabica)
+
+ results = (
+ await Supplier.nodes.fetch_relations("coffees").order_by("-coffees|since").all()
+ )
+ assert len(results) == 2
+ assert results[0][0] == supplier1
+ assert results[1][0] == supplier2
+
+ results = (
+ await Supplier.nodes.fetch_relations("coffees").order_by("coffees|since").all()
+ )
+ assert len(results) == 2
+ assert results[0][0] == supplier2
+ assert results[1][0] == supplier1
+
+
@mark_async_test
async def test_fetch_relations():
arabica = await Species(name="Arabica").save()
robusta = await Species(name="Robusta").save()
- nescafe = await Coffee(name="Nescafe 1000", price=99).save()
- nescafe_gold = await Coffee(name="Nescafe 1001", price=11).save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save()
- tesco = await Supplier(name="Sainsburys", delivery_cost=3).save()
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
await nescafe.suppliers.connect(tesco)
await nescafe_gold.suppliers.connect(tesco)
await nescafe.species.connect(arabica)
result = (
- await Supplier.nodes.filter(name="Sainsburys")
+ await Supplier.nodes.filter(name="Tesco")
.fetch_relations("coffees__species")
.all()
)
@@ -568,11 +695,11 @@ async def test_fetch_relations():
.fetch_relations(Optional("coffees__suppliers"))
.all()
)
- assert result[0][0] is None
+ assert len(result) == 0
if AsyncUtil.is_async_code:
count = (
- await Supplier.nodes.filter(name="Sainsburys")
+ await Supplier.nodes.filter(name="Tesco")
.fetch_relations("coffees__species")
.get_len()
)
@@ -580,20 +707,337 @@ async def test_fetch_relations():
assert (
await Supplier.nodes.fetch_relations("coffees__species")
- .filter(name="Sainsburys")
+ .filter(name="Tesco")
.check_contains(tesco)
)
else:
count = len(
- Supplier.nodes.filter(name="Sainsburys")
+ Supplier.nodes.filter(name="Tesco")
.fetch_relations("coffees__species")
.all()
)
assert count == 1
assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter(
- name="Sainsburys"
+ name="Tesco"
+ )
+
+
+@mark_async_test
+async def test_traverse_and_order_by():
+ arabica = await Species(name="Arabica").save()
+ robusta = await Species(name="Robusta").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe Gold", price=110).save()
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(tesco)
+ await nescafe.species.connect(arabica)
+ await nescafe_gold.species.connect(robusta)
+
+ results = (
+ await Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all()
+ )
+ assert len(results) == 2
+ assert len(results[0]) == 3 # 2 nodes and 1 relation
+ assert results[0][0] == robusta
+ assert results[1][0] == arabica
+
+
+@mark_async_test
+async def test_annotate_and_collect():
+ arabica = await Species(name="Arabica").save()
+ robusta = await Species(name="Robusta").save()
+ nescafe = await Coffee(name="Nescafe 1002", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe 1003", price=11).save()
+
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(tesco)
+ await nescafe.species.connect(arabica)
+ await nescafe_gold.species.connect(robusta)
+ await nescafe_gold.species.connect(arabica)
+
+ result = (
+ await Supplier.nodes.traverse_relations(species="coffees__species")
+ .annotate(Collect("species"))
+ .all()
+ )
+ assert len(result) == 1
+    assert len(result[0][1][0]) == 3  # 3 species entries (Arabica appears twice)
+
+ result = (
+ await Supplier.nodes.traverse_relations(species="coffees__species")
+ .annotate(Collect("species", distinct=True))
+ .all()
+ )
+ assert len(result[0][1][0]) == 2 # 2 species must be there
+
+ result = (
+ await Supplier.nodes.traverse_relations(species="coffees__species")
+ .annotate(all_species=Collect("species", distinct=True))
+ .all()
+ )
+ assert len(result[0][1][0]) == 2 # 2 species must be there
+
+ result = (
+ await Supplier.nodes.traverse_relations("coffees__species")
+ .annotate(
+ all_species=Collect(NodeNameResolver("coffees__species"), distinct=True),
+ all_species_rels=Collect(
+ RelationNameResolver("coffees__species"), distinct=True
+ ),
)
+ .all()
+ )
+ assert len(result[0][1][0]) == 2 # 2 species must be there
+ assert len(result[0][2][0]) == 3 # 3 species relations must be there
+
+
+@mark_async_test
+async def test_resolve_subgraph():
+ arabica = await Species(name="Arabica").save()
+ robusta = await Species(name="Robusta").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save()
+
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(tesco)
+ await nescafe.species.connect(arabica)
+ await nescafe_gold.species.connect(robusta)
+
+ with raises(
+ RuntimeError,
+ match=re.escape(
+ "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
+ ),
+ ):
+ result = await Supplier.nodes.resolve_subgraph()
+
+ with raises(
+ NotImplementedError,
+ match=re.escape(
+ "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
+ ),
+ ):
+ result = await Supplier.nodes.traverse_relations(
+ "coffees__species"
+ ).resolve_subgraph()
+
+ result = await Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph()
+ assert len(result) == 2
+
+ assert hasattr(result[0], "_relations")
+ assert "coffees" in result[0]._relations
+ coffees = result[0]._relations["coffees"]
+ assert hasattr(coffees, "_relations")
+ assert "species" in coffees._relations
+
+ assert hasattr(result[1], "_relations")
+ assert "coffees" in result[1]._relations
+ coffees = result[1]._relations["coffees"]
+ assert hasattr(coffees, "_relations")
+ assert "species" in coffees._relations
+
+
+@mark_async_test
+async def test_resolve_subgraph_optional():
+ arabica = await Species(name="Arabica").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save()
+
+ tesco = await Supplier(name="Tesco", delivery_cost=3).save()
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(tesco)
+ await nescafe.species.connect(arabica)
+
+ result = await Supplier.nodes.fetch_relations(
+ Optional("coffees__species")
+ ).resolve_subgraph()
+ assert len(result) == 1
+
+ assert hasattr(result[0], "_relations")
+ assert "coffees" in result[0]._relations
+ coffees = result[0]._relations["coffees"]
+ assert hasattr(coffees, "_relations")
+ assert "species" in coffees._relations
+ assert coffees._relations["species"] == arabica
+
+
+@mark_async_test
+async def test_subquery():
+ arabica = await Species(name="Arabica").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ await nescafe.suppliers.connect(supplier1)
+ await nescafe.suppliers.connect(supplier2)
+ await nescafe.species.connect(arabica)
+
+ result = await Coffee.nodes.subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers")
+ .intermediate_transform(
+ {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"]
+ )
+ .annotate(supps=Last(Collect("suppliers"))),
+ ["supps"],
+ )
+ result = await result.all()
+ assert len(result) == 1
+ assert len(result[0]) == 2
+ assert result[0][0] == supplier2
+
+ with raises(
+ RuntimeError,
+ match=re.escape("Variable 'unknown' is not returned by subquery."),
+ ):
+ result = await Coffee.nodes.subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
+ supps=Collect("suppliers")
+ ),
+ ["unknown"],
+ )
+
+
+@mark_async_test
+async def test_intermediate_transform():
+ arabica = await Species(name="Arabica").save()
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
+ await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
+ await nescafe.species.connect(arabica)
+
+ result = (
+ await Coffee.nodes.fetch_relations("suppliers")
+ .intermediate_transform(
+ {
+ "coffee": "coffee",
+ "suppliers": NodeNameResolver("suppliers"),
+ "r": RelationNameResolver("suppliers"),
+ },
+ ordering=["-r.since"],
+ )
+ .annotate(oldest_supplier=Last(Collect("suppliers")))
+ .all()
+ )
+
+ assert len(result) == 1
+ assert result[0] == supplier2
+
+ with raises(
+ ValueError,
+ match=re.escape(
+ r"Wrong source type specified for variable 'test', should be a string or an instance of NodeNameResolver or RelationNameResolver"
+ ),
+ ):
+ Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform(
+ {
+ "test": Collect("suppliers"),
+ }
+ )
+ with raises(
+ ValueError,
+ match=re.escape(
+ r"You must provide one variable at least when calling intermediate_transform()"
+ ),
+ ):
+ Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform(
+ {}
+ )
+
+
+@mark_async_test
+async def test_mix_functions():
+ # Test with a mix of all advanced querying functions
+
+ eiffel_tower = await Building(name="Eiffel Tower").save()
+ empire_state_building = await Building(name="Empire State Building").save()
+ miranda = await Student(name="Miranda").save()
+ await miranda.lives_in.connect(empire_state_building)
+ jean_pierre = await Student(name="Jean-Pierre").save()
+ await jean_pierre.lives_in.connect(eiffel_tower)
+ mireille = await Student(name="Mireille").save()
+ mimoun_jr = await Student(name="Mimoun Jr").save()
+ mimoun = await Student(name="Mimoun").save()
+ await mireille.lives_in.connect(eiffel_tower)
+ await mimoun_jr.lives_in.connect(eiffel_tower)
+ await mimoun.lives_in.connect(eiffel_tower)
+ await mimoun.parents.connect(mireille)
+ await mimoun.children.connect(mimoun_jr)
+ math = await Course(name="Math").save()
+ dessin = await Course(name="Dessin").save()
+ await mimoun.courses.connect(
+ math,
+ {
+ "level": "1.2",
+ "start_date": datetime(2020, 6, 2),
+ "end_date": datetime(2020, 12, 31),
+ },
+ )
+ await mimoun.courses.connect(
+ math,
+ {
+ "level": "1.1",
+ "start_date": datetime(2020, 1, 1),
+ "end_date": datetime(2020, 6, 1),
+ },
+ )
+ await mimoun_jr.courses.connect(
+ math,
+ {
+ "level": "1.1",
+ "start_date": datetime(2020, 1, 1),
+ "end_date": datetime(2020, 6, 1),
+ },
+ )
+
+ await mimoun_jr.preferred_course.connect(dessin)
+
+ full_nodeset = (
+ await Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower")
+ .order_by("name")
+ .fetch_relations(
+ "parents",
+ Optional("children__preferred_course"),
+ )
+ .subquery(
+ Student.nodes.fetch_relations("courses")
+ .intermediate_transform(
+ {"rel": RelationNameResolver("courses")},
+ ordering=[
+ RawCypher("toInteger(split(rel.level, '.')[0])"),
+ RawCypher("toInteger(split(rel.level, '.')[1])"),
+ "rel.end_date",
+ "rel.start_date",
+ ],
+ )
+ .annotate(
+ latest_course=Last(Collect("rel")),
+ ),
+ ["latest_course"],
+ )
+ )
+
+ subgraph = await full_nodeset.annotate(
+ children=Collect(NodeNameResolver("children"), distinct=True),
+ children_preferred_course=Collect(
+ NodeNameResolver("children__preferred_course"), distinct=True
+ ),
+ ).resolve_subgraph()
+
+ assert len(subgraph) == 2
+ assert subgraph[0] == mimoun
+ assert subgraph[1] == mimoun_jr
+ mimoun_returned_rels = subgraph[0]._relations
+ assert mimoun_returned_rels["children"] == mimoun_jr
+ assert mimoun_returned_rels["children"]._relations["preferred_course"] == dessin
+ assert mimoun_returned_rels["parents"] == mireille
+ assert mimoun_returned_rels["latest_course_relationship"].level == "1.2"
@mark_async_test
@@ -639,9 +1083,6 @@ async def test_in_filter_with_array_property():
async def test_async_iterator():
n = 10
if AsyncUtil.is_async_code:
- for c in await Coffee.nodes:
- await c.delete()
-
for i in range(n):
await Coffee(name=f"xxx_{i}", price=i).save()
diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py
index 59a5e385..f0599e01 100644
--- a/test/async_/test_paths.py
+++ b/test/async_/test_paths.py
@@ -85,12 +85,3 @@ async def test_path_instantiation():
assert type(path_rels[0]) is PersonLivesInCity
assert type(path_rels[1]) is AsyncStructuredRel
-
- await c1.delete()
- await c2.delete()
- await ct1.delete()
- await ct2.delete()
- await p1.delete()
- await p2.delete()
- await p3.delete()
- await p4.delete()
diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py
index 4f3eab2d..646f2cd5 100644
--- a/test/async_/test_properties.py
+++ b/test/async_/test_properties.py
@@ -309,6 +309,21 @@ def test_json():
assert prop.deflate(value) == '{"test": [1, 2, 3]}'
assert prop.inflate('{"test": [1, 2, 3]}') == value
+ value_with_unicode = {"test": [1, 2, 3, "©"]}
+ assert prop.deflate(value_with_unicode) == '{"test": [1, 2, 3, "\\u00a9"]}'
+ assert prop.inflate('{"test": [1, 2, 3, "\\u00a9"]}') == value_with_unicode
+
+
+def test_json_unicode():
+ prop = JSONProperty(ensure_ascii=False)
+ prop.name = "json"
+ prop.owner = FooBar
+
+ value = {"test": [1, 2, 3, "©"]}
+
+ assert prop.deflate(value) == '{"test": [1, 2, 3, "©"]}'
+ assert prop.inflate('{"test": [1, 2, 3, "©"]}') == value
+
def test_indexed():
indexed = StringProperty(index=True)
@@ -418,10 +433,6 @@ async def test_independent_property_name():
rel = await x.knows.relationship(x)
assert rel.known_for == r.known_for
- # -- cleanup --
-
- await x.delete()
-
@mark_async_test
async def test_independent_property_name_for_semi_structured():
@@ -455,8 +466,6 @@ class TestDBNamePropertySemiStructuredNode(AsyncSemiStructuredNode):
# assert not hasattr(from_get, "title")
assert from_get.extra == "data"
- await semi.delete()
-
@mark_async_test
async def test_independent_property_name_get_or_create():
@@ -475,9 +484,6 @@ class TestNode(AsyncStructuredNode):
assert node_properties["name"] == "jim"
assert "name_" not in node_properties
- # delete node afterwards
- await x.delete()
-
@mark.parametrize("normalized_class", (NormalizedProperty,))
def test_normalized_property(normalized_class):
@@ -648,9 +654,6 @@ class ConstrainedTestNode(AsyncStructuredNode):
node_properties = get_graph_entity_properties(results[0][0])
assert node_properties["unique_required_property"] == "unique and required"
- # delete node afterwards
- await x.delete()
-
@mark_async_test
async def test_unique_index_prop_enforced():
@@ -675,11 +678,6 @@ class UniqueNullableNameNode(AsyncStructuredNode):
results, _ = await adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n")
assert len(results) == 3
- # Delete nodes afterwards
- await x.delete()
- await y.delete()
- await z.delete()
-
def test_alias_property():
class AliasedClass(AsyncStructuredNode):
diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py
index 59d523c5..de7a13e5 100644
--- a/test/async_/test_transactions.py
+++ b/test/async_/test_transactions.py
@@ -14,9 +14,6 @@ class APerson(AsyncStructuredNode):
@mark_async_test
async def test_rollback_and_commit_transaction():
- for p in await APerson.nodes:
- await p.delete()
-
await APerson(name="Roger").save()
await adb.begin()
@@ -41,8 +38,6 @@ async def in_a_tx(*names):
@mark_async_test
async def test_transaction_decorator():
await adb.install_labels(APerson)
- for p in await APerson.nodes:
- await p.delete()
# should work
await in_a_tx("Roger")
@@ -68,9 +63,6 @@ async def test_transaction_as_a_context():
@mark_async_test
async def test_query_inside_transaction():
- for p in await APerson.nodes:
- await p.delete()
-
async with adb.transaction:
await APerson(name="Alice").save()
await APerson(name="Bob").save()
@@ -119,9 +111,6 @@ async def in_a_tx_with_bookmark(*names):
@mark_async_test
async def test_bookmark_transaction_decorator():
- for p in await APerson.nodes:
- await p.delete()
-
# should work
result, bookmarks = await in_a_tx_with_bookmark("Ruth", bookmarks=None)
assert result is None
@@ -181,9 +170,6 @@ async def test_bookmark_passed_in_to_context(spy_on_db_begin):
@mark_async_test
async def test_query_inside_bookmark_transaction():
- for p in await APerson.nodes:
- await p.delete()
-
async with adb.transaction as transaction:
await APerson(name="Alice").save()
await APerson(name="Bob").save()
diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py
index d2cd787e..cbe38140 100644
--- a/test/sync_/conftest.py
+++ b/test/sync_/conftest.py
@@ -1,6 +1,9 @@
import os
import warnings
-from test._async_compat import mark_sync_session_auto_fixture
+from test._async_compat import (
+ mark_async_function_auto_fixture,
+ mark_sync_session_auto_fixture,
+)
from neomodel import config, db
@@ -41,8 +44,12 @@ def setup_neo4j_session(request):
db.cypher_query("GRANT ROLE publisher TO troygreene")
db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin")
-
-@mark_sync_session_auto_fixture
-def cleanup():
yield
+
db.close_connection()
+
+
+@mark_async_function_auto_fixture
+def setUp():
+ db.cypher_query("MATCH (n) DETACH DELETE n")
+ yield
diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py
index a059f7f2..fab4f0d7 100644
--- a/test/sync_/test_issue283.py
+++ b/test/sync_/test_issue283.py
@@ -9,6 +9,7 @@
idea remains the same: "Instantiate the correct type of node at the end of
a relationship as specified by the model"
"""
+
import random
from test._async_compat import mark_sync_test
@@ -116,10 +117,6 @@ def test_automatic_result_resolution():
# TechnicalPerson (!NOT basePerson!)
assert type((A.friends_with)[0]) is TechnicalPerson
- A.delete()
- B.delete()
- C.delete()
-
@mark_sync_test
def test_recursive_automatic_result_resolution():
@@ -158,11 +155,6 @@ def test_recursive_automatic_result_resolution():
# Assert that primitive data types remain primitive data types
assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring)
- A.delete()
- B.delete()
- C.delete()
- D.delete()
-
@mark_sync_test
def test_validation_with_inheritance_from_db():
@@ -216,12 +208,6 @@ def test_validation_with_inheritance_from_db():
)
assert type((D.friends_with)[0]) is PilotPerson
- A.delete()
- B.delete()
- C.delete()
- D.delete()
- E.delete()
-
@mark_sync_test
def test_validation_enforcement_to_db():
@@ -265,13 +251,6 @@ def test_validation_enforcement_to_db():
with pytest.raises(ValueError):
A.friends_with.connect(F)
- A.delete()
- B.delete()
- C.delete()
- D.delete()
- E.delete()
- F.delete()
-
@mark_sync_test
def test_failed_result_resolution():
@@ -310,9 +289,6 @@ class RandomPerson(BasePerson):
for some_friend in friends:
print(some_friend.name)
- A.delete()
- B.delete()
-
@mark_sync_test
def test_node_label_mismatch():
@@ -469,6 +445,10 @@ def test_resolve_inexistent_relationship():
Attempting to resolve an inexistent relationship should raise an exception
:return:
"""
+ A = TechnicalPerson(name="Michael Knight", expertise="Cars").save()
+ B = TechnicalPerson(name="Luke Duke", expertise="Lasers").save()
+
+ A.friends_with.connect(B)
# Forget about the FRIENDS_WITH Relationship.
del db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])]
@@ -478,7 +458,7 @@ def test_resolve_inexistent_relationship():
match=r"Relationship of type .* does not resolve to any of the known objects.*",
):
query_data = db.cypher_query(
- "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) "
+ "MATCH (:TechnicalPerson)-[r:FRIENDS_WITH]->(:TechnicalPerson) "
"RETURN DISTINCT r",
resolve_objects=True,
)
diff --git a/test/sync_/test_issue600.py b/test/sync_/test_issue600.py
index f6b5a10b..181a156d 100644
--- a/test/sync_/test_issue600.py
+++ b/test/sync_/test_issue600.py
@@ -63,11 +63,6 @@ def test_relationship_definer_second_sibling():
B.rel_2.connect(C)
C.rel_3.connect(A)
- # Clean up
- A.delete()
- B.delete()
- C.delete()
-
@mark_sync_test
def test_relationship_definer_parent_last():
@@ -80,8 +75,3 @@ def test_relationship_definer_parent_last():
A.rel_1.connect(B)
B.rel_2.connect(C)
C.rel_3.connect(A)
-
- # Clean up
- A.delete()
- B.delete()
- C.delete()
diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py
index d9c90bb9..78909860 100644
--- a/test/sync_/test_match_api.py
+++ b/test/sync_/test_match_api.py
@@ -1,3 +1,4 @@
+import re
from datetime import datetime
from test._async_compat import mark_sync_test
@@ -15,10 +16,22 @@
StructuredNode,
StructuredRel,
UniqueIdProperty,
+ ZeroOrOne,
+ db,
)
from neomodel._async_compat.util import Util
from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined
-from neomodel.sync_.match import NodeSet, Optional, QueryBuilder, Traversal
+from neomodel.sync_.match import (
+ Collect,
+ Last,
+ NodeNameResolver,
+ NodeSet,
+ Optional,
+ QueryBuilder,
+ RawCypher,
+ RelationNameResolver,
+ Traversal,
+)
class SupplierRel(StructuredRel):
@@ -29,7 +42,7 @@ class SupplierRel(StructuredRel):
class Supplier(StructuredNode):
name = StringProperty()
delivery_cost = IntegerProperty()
- coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS")
+ coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS", model=SupplierRel)
class Species(StructuredNode):
@@ -72,6 +85,40 @@ class PersonX(StructuredNode):
city = RelationshipTo(CityX, "LIVES_IN")
+class SoftwareDependency(StructuredNode):
+ name = StringProperty(required=True)
+ version = StringProperty(required=True)
+
+
+class HasCourseRel(StructuredRel):
+ level = StringProperty()
+ start_date = DateTimeProperty()
+ end_date = DateTimeProperty()
+
+
+class Course(StructuredNode):
+ name = StringProperty()
+
+
+class Building(StructuredNode):
+ name = StringProperty()
+
+
+class Student(StructuredNode):
+ name = StringProperty()
+
+ parents = RelationshipTo("Student", "HAS_PARENT", model=StructuredRel)
+ children = RelationshipFrom("Student", "HAS_PARENT", model=StructuredRel)
+ lives_in = RelationshipTo(Building, "LIVES_IN", model=StructuredRel)
+ courses = RelationshipTo(Course, "HAS_COURSE", model=HasCourseRel)
+ preferred_course = RelationshipTo(
+ Course,
+ "HAS_PREFERRED_COURSE",
+ model=StructuredRel,
+ cardinality=ZeroOrOne,
+ )
+
+
@mark_sync_test
def test_filter_exclude_via_labels():
Coffee(name="Java", price=99).save()
@@ -137,7 +184,7 @@ def test_get():
@mark_sync_test
def test_simple_traverse_with_filter():
nescafe = Coffee(name="Nescafe2", price=99).save()
- tesco = Supplier(name="Sainsburys", delivery_cost=2).save()
+ tesco = Supplier(name="Tesco", delivery_cost=2).save()
nescafe.suppliers.connect(tesco)
qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()))
@@ -149,7 +196,7 @@ def test_simple_traverse_with_filter():
assert qb._ast.match
assert qb._ast.return_clause.startswith("suppliers")
assert len(results) == 1
- assert results[0].name == "Sainsburys"
+ assert results[0].name == "Tesco"
@mark_sync_test
@@ -202,9 +249,6 @@ def test_len_and_iter_and_bool():
@mark_sync_test
def test_slice():
- for c in Coffee.nodes:
- c.delete()
-
Coffee(name="Icelands finest").save()
Coffee(name="Britains finest").save()
Coffee(name="Japans finest").save()
@@ -273,9 +317,6 @@ def test_contains():
@mark_sync_test
def test_order_by():
- for c in Coffee.nodes:
- c.delete()
-
c1 = Coffee(name="Icelands finest", price=5).save()
c2 = Coffee(name="Britains finest", price=10).save()
c3 = Coffee(name="Japans finest", price=35).save()
@@ -296,13 +337,13 @@ def test_order_by():
ns = ns.order_by("?")
qb = QueryBuilder(ns).build_ast()
assert qb._ast.with_clause == "coffee, rand() as r"
- assert qb._ast.order_by == "r"
+ assert qb._ast.order_by == ["r"]
with raises(
ValueError,
match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.",
):
- Coffee.nodes.order_by("id")
+ Coffee.nodes.order_by("id").all()
# Test order by on a relationship
l = Supplier(name="lidl2").save()
@@ -317,10 +358,27 @@ def test_order_by():
@mark_sync_test
-def test_extra_filters():
- for c in Coffee.nodes:
- c.delete()
+def test_order_by_rawcypher():
+ d1 = SoftwareDependency(name="Package1", version="1.0.0").save()
+ d2 = SoftwareDependency(name="Package2", version="1.4.0").save()
+ d3 = SoftwareDependency(name="Package3", version="2.5.5").save()
+
+ assert (
+ SoftwareDependency.nodes.order_by(
+ RawCypher("toInteger(split($n.version, '.')[0]) DESC"),
+ ).all()
+ )[0] == d3
+
+ with raises(
+ ValueError, match=r"RawCypher: Do not include any action that has side effect"
+ ):
+ SoftwareDependency.nodes.order_by(
+ RawCypher("DETACH DELETE $n"),
+ )
+
+@mark_sync_test
+def test_extra_filters():
c1 = Coffee(name="Icelands finest", price=5, id_=1).save()
c2 = Coffee(name="Britains finest", price=10, id_=2).save()
c3 = Coffee(name="Japans finest", price=35, id_=3).save()
@@ -392,10 +450,6 @@ def test_empty_filters():
``get_queryset`` function in ``GenericAPIView`` should returns
``NodeSet`` object.
"""
-
- for c in Coffee.nodes:
- c.delete()
-
c1 = Coffee(name="Super", price=5, id_=1).save()
c2 = Coffee(name="Puper", price=10, id_=2).save()
@@ -419,10 +473,6 @@ def test_empty_filters():
@mark_sync_test
def test_q_filters():
- # Test where no children and self.connector != conn ?
- for c in Coffee.nodes:
- c.delete()
-
c1 = Coffee(name="Icelands finest", price=5, id_=1).save()
c2 = Coffee(name="Britains finest", price=10, id_=2).save()
c3 = Coffee(name="Japans finest", price=35, id_=3).save()
@@ -513,7 +563,7 @@ def test_traversal_filter_left_hand_statement():
nescafe = Coffee(name="Nescafe2", price=99).save()
nescafe_gold = Coffee(name="Nescafe gold", price=11).save()
- tesco = Supplier(name="Sainsburys", delivery_cost=3).save()
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
biedronka = Supplier(name="Biedronka", delivery_cost=5).save()
lidl = Supplier(name="Lidl", delivery_cost=3).save()
@@ -528,22 +578,96 @@ def test_traversal_filter_left_hand_statement():
assert lidl in lidl_supplier
+@mark_sync_test
+def test_filter_with_traversal():
+ arabica = Species(name="Arabica").save()
+ robusta = Species(name="Robusta").save()
+ nescafe = Coffee(name="Nescafe", price=11).save()
+ nescafe_gold = Coffee(name="Nescafe Gold", price=99).save()
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
+ nescafe.suppliers.connect(tesco)
+ nescafe_gold.suppliers.connect(tesco)
+ nescafe.species.connect(arabica)
+ nescafe_gold.species.connect(robusta)
+
+ results = Coffee.nodes.filter(species__name="Arabica").all()
+ assert len(results) == 1
+ assert len(results[0]) == 3
+ assert results[0][0] == nescafe
+
+ results_multi_hop = Supplier.nodes.filter(coffees__species__name="Arabica").all()
+ assert len(results_multi_hop) == 1
+ assert results_multi_hop[0][0] == tesco
+
+ no_results = Supplier.nodes.filter(coffees__species__name="Noffee").all()
+ assert no_results == []
+
+
+@mark_sync_test
+def test_relation_prop_filtering():
+ arabica = Species(name="Arabica").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
+ nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
+ nescafe.species.connect(arabica)
+
+ results = Supplier.nodes.filter(
+ **{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
+ ).all()
+
+ assert len(results) == 1
+ assert results[0][0] == supplier1
+
+ # Test it works with mixed argument syntaxes
+ results2 = Supplier.nodes.filter(
+ name="Supplier 1",
+ coffees__name="Nescafe",
+ **{"coffees|since__gt": datetime(2018, 4, 1, 0, 0)},
+ ).all()
+
+ assert len(results2) == 1
+ assert results2[0][0] == supplier1
+
+
+@mark_sync_test
+def test_relation_prop_ordering():
+ arabica = Species(name="Arabica").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
+ nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
+ nescafe.species.connect(arabica)
+
+ results = Supplier.nodes.fetch_relations("coffees").order_by("-coffees|since").all()
+ assert len(results) == 2
+ assert results[0][0] == supplier1
+ assert results[1][0] == supplier2
+
+ results = Supplier.nodes.fetch_relations("coffees").order_by("coffees|since").all()
+ assert len(results) == 2
+ assert results[0][0] == supplier2
+ assert results[1][0] == supplier1
+
+
@mark_sync_test
def test_fetch_relations():
arabica = Species(name="Arabica").save()
robusta = Species(name="Robusta").save()
- nescafe = Coffee(name="Nescafe 1000", price=99).save()
- nescafe_gold = Coffee(name="Nescafe 1001", price=11).save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = Coffee(name="Nescafe Gold", price=11).save()
- tesco = Supplier(name="Sainsburys", delivery_cost=3).save()
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
nescafe.suppliers.connect(tesco)
nescafe_gold.suppliers.connect(tesco)
nescafe.species.connect(arabica)
result = (
- Supplier.nodes.filter(name="Sainsburys")
- .fetch_relations("coffees__species")
- .all()
+ Supplier.nodes.filter(name="Tesco").fetch_relations("coffees__species").all()
)
assert len(result[0]) == 5
assert arabica in result[0]
@@ -557,11 +681,11 @@ def test_fetch_relations():
.fetch_relations(Optional("coffees__suppliers"))
.all()
)
- assert result[0][0] is None
+ assert len(result) == 0
if Util.is_async_code:
count = (
- Supplier.nodes.filter(name="Sainsburys")
+ Supplier.nodes.filter(name="Tesco")
.fetch_relations("coffees__species")
.__len__()
)
@@ -569,20 +693,335 @@ def test_fetch_relations():
assert (
Supplier.nodes.fetch_relations("coffees__species")
- .filter(name="Sainsburys")
+ .filter(name="Tesco")
.__contains__(tesco)
)
else:
count = len(
- Supplier.nodes.filter(name="Sainsburys")
+ Supplier.nodes.filter(name="Tesco")
.fetch_relations("coffees__species")
.all()
)
assert count == 1
assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter(
- name="Sainsburys"
+ name="Tesco"
+ )
+
+
+@mark_sync_test
+def test_traverse_and_order_by():
+ arabica = Species(name="Arabica").save()
+ robusta = Species(name="Robusta").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = Coffee(name="Nescafe Gold", price=110).save()
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
+ nescafe.suppliers.connect(tesco)
+ nescafe_gold.suppliers.connect(tesco)
+ nescafe.species.connect(arabica)
+ nescafe_gold.species.connect(robusta)
+
+ results = Species.nodes.fetch_relations("coffees").order_by("-coffees__price").all()
+ assert len(results) == 2
+ assert len(results[0]) == 3 # 2 nodes and 1 relation
+ assert results[0][0] == robusta
+ assert results[1][0] == arabica
+
+
+@mark_sync_test
+def test_annotate_and_collect():
+ arabica = Species(name="Arabica").save()
+ robusta = Species(name="Robusta").save()
+ nescafe = Coffee(name="Nescafe 1002", price=99).save()
+ nescafe_gold = Coffee(name="Nescafe 1003", price=11).save()
+
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
+ nescafe.suppliers.connect(tesco)
+ nescafe_gold.suppliers.connect(tesco)
+ nescafe.species.connect(arabica)
+ nescafe_gold.species.connect(robusta)
+ nescafe_gold.species.connect(arabica)
+
+ result = (
+ Supplier.nodes.traverse_relations(species="coffees__species")
+ .annotate(Collect("species"))
+ .all()
+ )
+ assert len(result) == 1
+ assert len(result[0][1][0]) == 3 # 3 species must be there (with 2 duplicates)
+
+ result = (
+ Supplier.nodes.traverse_relations(species="coffees__species")
+ .annotate(Collect("species", distinct=True))
+ .all()
+ )
+ assert len(result[0][1][0]) == 2 # 2 species must be there
+
+ result = (
+ Supplier.nodes.traverse_relations(species="coffees__species")
+ .annotate(all_species=Collect("species", distinct=True))
+ .all()
+ )
+ assert len(result[0][1][0]) == 2 # 2 species must be there
+
+ result = (
+ Supplier.nodes.traverse_relations("coffees__species")
+ .annotate(
+ all_species=Collect(NodeNameResolver("coffees__species"), distinct=True),
+ all_species_rels=Collect(
+ RelationNameResolver("coffees__species"), distinct=True
+ ),
+ )
+ .all()
+ )
+ assert len(result[0][1][0]) == 2 # 2 species must be there
+ assert len(result[0][2][0]) == 3 # 3 species relations must be there
+
+
+@mark_sync_test
+def test_resolve_subgraph():
+ arabica = Species(name="Arabica").save()
+ robusta = Species(name="Robusta").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = Coffee(name="Nescafe Gold", price=11).save()
+
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
+ nescafe.suppliers.connect(tesco)
+ nescafe_gold.suppliers.connect(tesco)
+ nescafe.species.connect(arabica)
+ nescafe_gold.species.connect(robusta)
+
+ with raises(
+ RuntimeError,
+ match=re.escape(
+ "Nothing to resolve. Make sure to include relations in the result using fetch_relations() or filter()."
+ ),
+ ):
+ result = Supplier.nodes.resolve_subgraph()
+
+ with raises(
+ NotImplementedError,
+ match=re.escape(
+ "You cannot use traverse_relations() with resolve_subgraph(), use fetch_relations() instead."
+ ),
+ ):
+ result = Supplier.nodes.traverse_relations(
+ "coffees__species"
+ ).resolve_subgraph()
+
+ result = Supplier.nodes.fetch_relations("coffees__species").resolve_subgraph()
+ assert len(result) == 2
+
+ assert hasattr(result[0], "_relations")
+ assert "coffees" in result[0]._relations
+ coffees = result[0]._relations["coffees"]
+ assert hasattr(coffees, "_relations")
+ assert "species" in coffees._relations
+
+ assert hasattr(result[1], "_relations")
+ assert "coffees" in result[1]._relations
+ coffees = result[1]._relations["coffees"]
+ assert hasattr(coffees, "_relations")
+ assert "species" in coffees._relations
+
+
+@mark_sync_test
+def test_resolve_subgraph_optional():
+ arabica = Species(name="Arabica").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ nescafe_gold = Coffee(name="Nescafe Gold", price=11).save()
+
+ tesco = Supplier(name="Tesco", delivery_cost=3).save()
+ nescafe.suppliers.connect(tesco)
+ nescafe_gold.suppliers.connect(tesco)
+ nescafe.species.connect(arabica)
+
+ result = Supplier.nodes.fetch_relations(
+ Optional("coffees__species")
+ ).resolve_subgraph()
+ assert len(result) == 1
+
+ assert hasattr(result[0], "_relations")
+ assert "coffees" in result[0]._relations
+ coffees = result[0]._relations["coffees"]
+ assert hasattr(coffees, "_relations")
+ assert "species" in coffees._relations
+ assert coffees._relations["species"] == arabica
+
+
+@mark_sync_test
+def test_subquery():
+ arabica = Species(name="Arabica").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ nescafe.suppliers.connect(supplier1)
+ nescafe.suppliers.connect(supplier2)
+ nescafe.species.connect(arabica)
+
+ result = Coffee.nodes.subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers")
+ .intermediate_transform(
+ {"suppliers": "suppliers"}, ordering=["suppliers.delivery_cost"]
+ )
+ .annotate(supps=Last(Collect("suppliers"))),
+ ["supps"],
+ )
+ result = result.all()
+ assert len(result) == 1
+ assert len(result[0]) == 2
+ assert result[0][0] == supplier2
+
+ with raises(
+ RuntimeError,
+ match=re.escape("Variable 'unknown' is not returned by subquery."),
+ ):
+ result = Coffee.nodes.subquery(
+ Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
+ supps=Collect("suppliers")
+ ),
+ ["unknown"],
+ )
+
+
+@mark_sync_test
+def test_intermediate_transform():
+ arabica = Species(name="Arabica").save()
+ nescafe = Coffee(name="Nescafe", price=99).save()
+ supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
+ supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
+
+ nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
+ nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
+ nescafe.species.connect(arabica)
+
+ result = (
+ Coffee.nodes.fetch_relations("suppliers")
+ .intermediate_transform(
+ {
+ "coffee": "coffee",
+ "suppliers": NodeNameResolver("suppliers"),
+ "r": RelationNameResolver("suppliers"),
+ },
+ ordering=["-r.since"],
+ )
+ .annotate(oldest_supplier=Last(Collect("suppliers")))
+ .all()
+ )
+
+ assert len(result) == 1
+ assert result[0] == supplier2
+
+ with raises(
+ ValueError,
+ match=re.escape(
+ r"Wrong source type specified for variable 'test', should be a string or an instance of NodeNameResolver or RelationNameResolver"
+ ),
+ ):
+ Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform(
+ {
+ "test": Collect("suppliers"),
+ }
+ )
+ with raises(
+ ValueError,
+ match=re.escape(
+ r"You must provide one variable at least when calling intermediate_transform()"
+ ),
+ ):
+ Coffee.nodes.traverse_relations(suppliers="suppliers").intermediate_transform(
+ {}
+ )
+
+
+@mark_sync_test
+def test_mix_functions():
+ # Test with a mix of all advanced querying functions
+
+ eiffel_tower = Building(name="Eiffel Tower").save()
+ empire_state_building = Building(name="Empire State Building").save()
+ miranda = Student(name="Miranda").save()
+ miranda.lives_in.connect(empire_state_building)
+ jean_pierre = Student(name="Jean-Pierre").save()
+ jean_pierre.lives_in.connect(eiffel_tower)
+ mireille = Student(name="Mireille").save()
+ mimoun_jr = Student(name="Mimoun Jr").save()
+ mimoun = Student(name="Mimoun").save()
+ mireille.lives_in.connect(eiffel_tower)
+ mimoun_jr.lives_in.connect(eiffel_tower)
+ mimoun.lives_in.connect(eiffel_tower)
+ mimoun.parents.connect(mireille)
+ mimoun.children.connect(mimoun_jr)
+ math = Course(name="Math").save()
+ dessin = Course(name="Dessin").save()
+ mimoun.courses.connect(
+ math,
+ {
+ "level": "1.2",
+ "start_date": datetime(2020, 6, 2),
+ "end_date": datetime(2020, 12, 31),
+ },
+ )
+ mimoun.courses.connect(
+ math,
+ {
+ "level": "1.1",
+ "start_date": datetime(2020, 1, 1),
+ "end_date": datetime(2020, 6, 1),
+ },
+ )
+ mimoun_jr.courses.connect(
+ math,
+ {
+ "level": "1.1",
+ "start_date": datetime(2020, 1, 1),
+ "end_date": datetime(2020, 6, 1),
+ },
+ )
+
+ mimoun_jr.preferred_course.connect(dessin)
+
+ full_nodeset = (
+ Student.nodes.filter(name__istartswith="m", lives_in__name="Eiffel Tower")
+ .order_by("name")
+ .fetch_relations(
+ "parents",
+ Optional("children__preferred_course"),
+ )
+ .subquery(
+ Student.nodes.fetch_relations("courses")
+ .intermediate_transform(
+ {"rel": RelationNameResolver("courses")},
+ ordering=[
+ RawCypher("toInteger(split(rel.level, '.')[0])"),
+ RawCypher("toInteger(split(rel.level, '.')[1])"),
+ "rel.end_date",
+ "rel.start_date",
+ ],
+ )
+ .annotate(
+ latest_course=Last(Collect("rel")),
+ ),
+ ["latest_course"],
)
+ )
+
+ subgraph = full_nodeset.annotate(
+ children=Collect(NodeNameResolver("children"), distinct=True),
+ children_preferred_course=Collect(
+ NodeNameResolver("children__preferred_course"), distinct=True
+ ),
+ ).resolve_subgraph()
+
+ assert len(subgraph) == 2
+ assert subgraph[0] == mimoun
+ assert subgraph[1] == mimoun_jr
+ mimoun_returned_rels = subgraph[0]._relations
+ assert mimoun_returned_rels["children"] == mimoun_jr
+ assert mimoun_returned_rels["children"]._relations["preferred_course"] == dessin
+ assert mimoun_returned_rels["parents"] == mireille
+ assert mimoun_returned_rels["latest_course_relationship"].level == "1.2"
@mark_sync_test
@@ -628,9 +1067,6 @@ def test_in_filter_with_array_property():
def test_async_iterator():
n = 10
if Util.is_async_code:
- for c in Coffee.nodes:
- c.delete()
-
for i in range(n):
Coffee(name=f"xxx_{i}", price=i).save()
diff --git a/test/sync_/test_paths.py b/test/sync_/test_paths.py
index 8e0ccf90..1a6429bf 100644
--- a/test/sync_/test_paths.py
+++ b/test/sync_/test_paths.py
@@ -85,12 +85,3 @@ def test_path_instantiation():
assert type(path_rels[0]) is PersonLivesInCity
assert type(path_rels[1]) is StructuredRel
-
- c1.delete()
- c2.delete()
- ct1.delete()
- ct2.delete()
- p1.delete()
- p2.delete()
- p3.delete()
- p4.delete()
diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py
index 1afe52a2..0b0b576b 100644
--- a/test/sync_/test_properties.py
+++ b/test/sync_/test_properties.py
@@ -303,6 +303,21 @@ def test_json():
assert prop.deflate(value) == '{"test": [1, 2, 3]}'
assert prop.inflate('{"test": [1, 2, 3]}') == value
+ value_with_unicode = {"test": [1, 2, 3, "©"]}
+ assert prop.deflate(value_with_unicode) == '{"test": [1, 2, 3, "\\u00a9"]}'
+ assert prop.inflate('{"test": [1, 2, 3, "\\u00a9"]}') == value_with_unicode
+
+
+def test_json_unicode():
+ prop = JSONProperty(ensure_ascii=False)
+ prop.name = "json"
+ prop.owner = FooBar
+
+ value = {"test": [1, 2, 3, "©"]}
+
+ assert prop.deflate(value) == '{"test": [1, 2, 3, "©"]}'
+ assert prop.inflate('{"test": [1, 2, 3, "©"]}') == value
+
def test_indexed():
indexed = StringProperty(index=True)
@@ -408,10 +423,6 @@ def test_independent_property_name():
rel = x.knows.relationship(x)
assert rel.known_for == r.known_for
- # -- cleanup --
-
- x.delete()
-
@mark_sync_test
def test_independent_property_name_for_semi_structured():
@@ -445,8 +456,6 @@ class TestDBNamePropertySemiStructuredNode(SemiStructuredNode):
# assert not hasattr(from_get, "title")
assert from_get.extra == "data"
- semi.delete()
-
@mark_sync_test
def test_independent_property_name_get_or_create():
@@ -465,9 +474,6 @@ class TestNode(StructuredNode):
assert node_properties["name"] == "jim"
assert "name_" not in node_properties
- # delete node afterwards
- x.delete()
-
@mark.parametrize("normalized_class", (NormalizedProperty,))
def test_normalized_property(normalized_class):
@@ -638,9 +644,6 @@ class ConstrainedTestNode(StructuredNode):
node_properties = get_graph_entity_properties(results[0][0])
assert node_properties["unique_required_property"] == "unique and required"
- # delete node afterwards
- x.delete()
-
@mark_sync_test
def test_unique_index_prop_enforced():
@@ -665,11 +668,6 @@ class UniqueNullableNameNode(StructuredNode):
results, _ = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n")
assert len(results) == 3
- # Delete nodes afterwards
- x.delete()
- y.delete()
- z.delete()
-
def test_alias_property():
class AliasedClass(StructuredNode):
diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py
index 834b538e..71ce479f 100644
--- a/test/sync_/test_transactions.py
+++ b/test/sync_/test_transactions.py
@@ -14,9 +14,6 @@ class APerson(StructuredNode):
@mark_sync_test
def test_rollback_and_commit_transaction():
- for p in APerson.nodes:
- p.delete()
-
APerson(name="Roger").save()
db.begin()
@@ -41,8 +38,6 @@ def in_a_tx(*names):
@mark_sync_test
def test_transaction_decorator():
db.install_labels(APerson)
- for p in APerson.nodes:
- p.delete()
# should work
in_a_tx("Roger")
@@ -68,9 +63,6 @@ def test_transaction_as_a_context():
@mark_sync_test
def test_query_inside_transaction():
- for p in APerson.nodes:
- p.delete()
-
with db.transaction:
APerson(name="Alice").save()
APerson(name="Bob").save()
@@ -119,9 +111,6 @@ def in_a_tx_with_bookmark(*names):
@mark_sync_test
def test_bookmark_transaction_decorator():
- for p in APerson.nodes:
- p.delete()
-
# should work
result, bookmarks = in_a_tx_with_bookmark("Ruth", bookmarks=None)
assert result is None
@@ -181,9 +170,6 @@ def test_bookmark_passed_in_to_context(spy_on_db_begin):
@mark_sync_test
def test_query_inside_bookmark_transaction():
- for p in APerson.nodes:
- p.delete()
-
with db.transaction as transaction:
APerson(name="Alice").save()
APerson(name="Bob").save()