diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index 8cef8a1e..6a2b5466 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -16,7 +16,7 @@ Aggregations neomodel implements some of the aggregation methods available in Cypher: -- Collect +- Collect (with distinct option) - Last These are usable in this way:: @@ -33,6 +33,8 @@ These are usable in this way:: .annotate(Last(Collect("last_species"))) .all() +Note how `annotate` is used to add the aggregation method to the query. + .. note:: Using the Last() method right after a Collect() without having set an ordering will return the last element in the list as it was returned by the database. diff --git a/doc/source/traversal.rst b/doc/source/traversal.rst index 4cbb2fd4..e4d94b34 100644 --- a/doc/source/traversal.rst +++ b/doc/source/traversal.rst @@ -78,7 +78,7 @@ With both `traverse_relations` and `fetch_relations`, you can force the use of a Resolve results --------------- -By default, fetch_relations will return a list of tuples. If your path looks like ``(startNode:Coffee)<-[r1]-(middleNode:Supplier)-[r2]->(endNode:Country)``, +By default, `fetch_relations` will return a list of tuples. If your path looks like ``(startNode:Coffee)<-[r1]-(middleNode:Supplier)-[r2]->(endNode:Country)``, then you will get a list of results, where each result is a list of ``(startNode, r1, middleNode, r2, endNode)``. These will be resolved by neomodel, so ``startNode`` will be a ``Coffee`` class as defined in neomodel for example. @@ -88,3 +88,7 @@ Using the `resolve_subgraph` method, you can get instead a list of "subgraphs", In this example, `results[0]` will be a `Coffee` object, with a `_relations` attribute. This will in turn have a `suppliers` and a `suppliers_relationship` attribute, which will contain the `Supplier` object and the relation object respectively. Recursively, the `Supplier` object will have a `country` attribute, which will contain the `Country` object. +.. note:: + + The `resolve_subgraph` method is only available for `fetch_relations` queries. This is because `traverse_relations` queries do not return any relations, and thus there is no need to resolve them. + diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 81944871..ab4bb9cf 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -91,6 +91,30 @@ class SoftwareDependency(AsyncStructuredNode): version = StringProperty(required=True) +class HasCourseRel(AsyncStructuredRel): + level = StringProperty() + start_date = DateTimeProperty() + end_date = DateTimeProperty() + + +class Course(AsyncStructuredNode): + name = StringProperty() + + +class Building(AsyncStructuredNode): + name = StringProperty() + + +class Student(AsyncStructuredNode): + name = StringProperty() + + parents = AsyncRelationshipTo("Student", "HAS_PARENT") + children = AsyncRelationshipFrom("Student", "HAS_PARENT") + lives_in = AsyncRelationshipTo(Building, "LIVES_IN") + has_course = AsyncRelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) + has_latest_course = AsyncRelationshipTo(Course, "HAS_COURSE") + + @mark_async_test async def test_filter_exclude_via_labels(): await Coffee(name="Java", price=99).save() @@ -557,7 +581,7 @@ async def test_traversal_filter_left_hand_statement(): nescafe = await Coffee(name="Nescafe2", price=99).save() nescafe_gold = await Coffee(name="Nescafe gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() biedronka = await Supplier(name="Biedronka", delivery_cost=5).save() lidl = await Supplier(name="Lidl", delivery_cost=3).save() @@ -583,7 +607,7 @@ async def test_filter_with_traversal(): robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=11).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=99).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -594,6 +618,15 @@ async def test_filter_with_traversal(): assert len(results[0]) == 3 assert results[0][0] == nescafe + results_multi_hop = await Supplier.nodes.filter( + coffees__species__name="Arabica" + ).all() + assert len(results_multi_hop) == 1 + assert results_multi_hop[0][0] == tesco + + no_results = await Supplier.nodes.filter(coffees__species__name="Noffee").all() + assert no_results == [] + @mark_async_test async def test_relation_prop_filtering(): @@ -616,6 +649,16 @@ async def test_relation_prop_filtering(): assert len(results) == 1 assert results[0][0] == supplier1 + # Test it works with mixed argument syntaxes + results2 = await Supplier.nodes.filter( + name="Supplier 1", + coffees__name="Nescafe", + **{"coffees|since__gt": datetime(2018, 4, 1, 0, 0)}, + ).all() + + assert len(results2) == 1 + assert results2[0][0] == supplier1 + @mark_async_test async def test_relation_prop_ordering(): @@ -656,7 +699,7 @@ async def test_fetch_relations(): nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -715,7 +758,7 @@ async def test_traverse_and_order_by(): robusta = await Species(name="Robusta").save() nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=110).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -740,7 +783,7 @@ async def test_annotate_and_collect(): nescafe = await Coffee(name="Nescafe 1002", price=99).save() nescafe_gold = await Coffee(name="Nescafe 1003", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -793,7 +836,7 @@ async def test_resolve_subgraph(): nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -842,7 +885,7 @@ async def test_resolve_subgraph_optional(): nescafe = await Coffee(name="Nescafe", price=99).save() nescafe_gold = await Coffee(name="Nescafe Gold", price=11).save() - tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = await Supplier(name="Tesco", delivery_cost=3).save() await nescafe.suppliers.connect(tesco) await nescafe_gold.suppliers.connect(tesco) await nescafe.species.connect(arabica) @@ -952,6 +995,89 @@ async def test_intermediate_transform(): ) +@mark_async_test +async def test_mix_functions(): + # Test with a mix of all advanced querying functions + + eiffel_tower = await Building(name="Eiffel Tower").save() + empire_state_building = await Building(name="Empire State Building").save() + miranda = await Student(name="Miranda").save() + await miranda.lives_in.connect(empire_state_building) + jean_pierre = await Student(name="Jean-Pierre").save() + await jean_pierre.lives_in.connect(eiffel_tower) + mireille = await Student(name="Mireille").save() + mimoun_jr = await Student(name="Mimoun Jr").save() + mimoun = await Student(name="Mimoun").save() + await mireille.lives_in.connect(eiffel_tower) + await mimoun_jr.lives_in.connect(eiffel_tower) + await mimoun.lives_in.connect(eiffel_tower) + await mimoun.parents.connect(mireille) + await mimoun.children.connect(mimoun_jr) + course = await Course(name="Math").save() + await mimoun.has_course.connect( + course, + { + "level": "1.2", + "start_date": datetime(2020, 6, 2), + "end_date": datetime(2020, 12, 31), + }, + ) + await mimoun.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + await mimoun_jr.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + + filtered_nodeset = Student.nodes.filter( + name__istartswith="m", lives_in__name="Eiffel Tower" + ) + full_nodeset = ( + await filtered_nodeset.order_by("name") + .traverse_relations( + "parents", + ) + .fetch_relations( + "lives_in", + Optional("children__has_latest_course"), + ) + .subquery( + filtered_nodeset.order_by("name") + .fetch_relations("has_course") + .intermediate_transform( + {"rel": RelationNameResolver("has_course")}, + ordering=[ + RawCypher("toInteger(split(rel.level, '.')[0])"), + RawCypher("toInteger(split(rel.level, '.')[1])"), + "rel.end_date", + "rel.start_date", + ], + ) + .annotate( + latest_course=Last(Collect("rel")), + ), + ["latest_course"], + ) + ) + + subgraph = await full_nodeset.annotate( + Collect(NodeNameResolver("children"), distinct=True), + Collect(NodeNameResolver("children__has_latest_course"), distinct=True), + ).resolve_subgraph() + + print(subgraph) + + @mark_async_test async def test_issue_795(): jim = await PersonX(name="Jim", age=3).save() # Create diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index e47e3396..139ad3b4 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -89,6 +89,30 @@ class SoftwareDependency(StructuredNode): version = StringProperty(required=True) +class HasCourseRel(StructuredRel): + level = StringProperty() + start_date = DateTimeProperty() + end_date = DateTimeProperty() + + +class Course(StructuredNode): + name = StringProperty() + + +class Building(StructuredNode): + name = StringProperty() + + +class Student(StructuredNode): + name = StringProperty() + + parents = RelationshipTo("Student", "HAS_PARENT") + children = RelationshipFrom("Student", "HAS_PARENT") + lives_in = RelationshipTo(Building, "LIVES_IN") + has_course = RelationshipTo(Course, "HAS_COURSE", model=HasCourseRel) + has_latest_course = RelationshipTo(Course, "HAS_COURSE") + + @mark_sync_test def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() @@ -553,7 +577,7 @@ def test_traversal_filter_left_hand_statement(): nescafe = Coffee(name="Nescafe2", price=99).save() nescafe_gold = Coffee(name="Nescafe gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() biedronka = Supplier(name="Biedronka", delivery_cost=5).save() lidl = Supplier(name="Lidl", delivery_cost=3).save() @@ -577,7 +601,7 @@ def test_filter_with_traversal(): robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=11).save() nescafe_gold = Coffee(name="Nescafe Gold", price=99).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -588,6 +612,13 @@ def test_filter_with_traversal(): assert len(results[0]) == 3 assert results[0][0] == nescafe + results_multi_hop = Supplier.nodes.filter(coffees__species__name="Arabica").all() + assert len(results_multi_hop) == 1 + assert results_multi_hop[0][0] == tesco + + no_results = Supplier.nodes.filter(coffees__species__name="Noffee").all() + assert no_results == [] + @mark_sync_test def test_relation_prop_filtering(): @@ -610,6 +641,16 @@ def test_relation_prop_filtering(): assert len(results) == 1 assert results[0][0] == supplier1 + # Test it works with mixed argument syntaxes + results2 = Supplier.nodes.filter( + name="Supplier 1", + coffees__name="Nescafe", + **{"coffees|since__gt": datetime(2018, 4, 1, 0, 0)}, + ).all() + + assert len(results2) == 1 + assert results2[0][0] == supplier1 + @mark_sync_test def test_relation_prop_ordering(): @@ -646,7 +687,7 @@ def test_fetch_relations(): nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -705,7 +746,7 @@ def test_traverse_and_order_by(): robusta = Species(name="Robusta").save() nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=110).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -728,7 +769,7 @@ def test_annotate_and_collect(): nescafe = Coffee(name="Nescafe 1002", price=99).save() nescafe_gold = Coffee(name="Nescafe 1003", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -781,7 +822,7 @@ def test_resolve_subgraph(): nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -830,7 +871,7 @@ def test_resolve_subgraph_optional(): nescafe = Coffee(name="Nescafe", price=99).save() nescafe_gold = Coffee(name="Nescafe Gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Tesco", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) @@ -940,6 +981,89 @@ def test_intermediate_transform(): ) +@mark_sync_test +def test_mix_functions(): + # Test with a mix of all advanced querying functions + + eiffel_tower = Building(name="Eiffel Tower").save() + empire_state_building = Building(name="Empire State Building").save() + miranda = Student(name="Miranda").save() + miranda.lives_in.connect(empire_state_building) + jean_pierre = Student(name="Jean-Pierre").save() + jean_pierre.lives_in.connect(eiffel_tower) + mireille = Student(name="Mireille").save() + mimoun_jr = Student(name="Mimoun Jr").save() + mimoun = Student(name="Mimoun").save() + mireille.lives_in.connect(eiffel_tower) + mimoun_jr.lives_in.connect(eiffel_tower) + mimoun.lives_in.connect(eiffel_tower) + mimoun.parents.connect(mireille) + mimoun.children.connect(mimoun_jr) + course = Course(name="Math").save() + mimoun.has_course.connect( + course, + { + "level": "1.2", + "start_date": datetime(2020, 6, 2), + "end_date": datetime(2020, 12, 31), + }, + ) + mimoun.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + mimoun_jr.has_course.connect( + course, + { + "level": "1.1", + "start_date": datetime(2020, 1, 1), + "end_date": datetime(2020, 6, 1), + }, + ) + + filtered_nodeset = Student.nodes.filter( + name__istartswith="m", lives_in__name="Eiffel Tower" + ) + full_nodeset = ( + filtered_nodeset.order_by("name") + .traverse_relations( + "parents", + ) + .fetch_relations( + "lives_in", + Optional("children__has_latest_course"), + ) + .subquery( + filtered_nodeset.order_by("name") + .fetch_relations("has_course") + .intermediate_transform( + {"rel": RelationNameResolver("has_course")}, + ordering=[ + RawCypher("toInteger(split(rel.level, '.')[0])"), + RawCypher("toInteger(split(rel.level, '.')[1])"), + "rel.end_date", + "rel.start_date", + ], + ) + .annotate( + latest_course=Last(Collect("rel")), + ), + ["latest_course"], + ) + ) + + subgraph = full_nodeset.annotate( + Collect(NodeNameResolver("children"), distinct=True), + Collect(NodeNameResolver("children__has_latest_course"), distinct=True), + ).resolve_subgraph() + + print(subgraph) + + @mark_sync_test def test_issue_795(): jim = PersonX(name="Jim", age=3).save() # Create