From c73f5233a1b4185a038051761d7a566ab2720e98 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Wed, 2 Oct 2024 14:13:19 +0200 Subject: [PATCH] Added support for last() scalar function --- neomodel/async_/match.py | 21 ++++++++++++++++++--- neomodel/sync_/match.py | 21 ++++++++++++++++++--- test/async_/test_match_api.py | 6 ++++-- test/sync_/test_match_api.py | 14 +++++++++++--- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index abe796e0..7048c86b 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import AsyncStructuredRel @@ -944,6 +944,21 @@ def __str__(self): return f"collect({self.input_name})" +@dataclass +class ScalarFunction: + """Base scalar function class.""" + + input_name: Union[str, AggregatingFunction] + + +@dataclass +class Last(ScalarFunction): + """last() function.""" + + def __str__(self) -> str: + return f"last({str(self.input_name)})" + + class AsyncNodeSet(AsyncBaseSet): """ A class representing as set of nodes matching common query parameters @@ -1167,8 +1182,8 @@ def traverse_relations(self, *relation_names, **aliased_relation_names): def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" - def register_extra_var(vardef, varname: str = None): - if isinstance(vardef, AggregatingFunction): + def register_extra_var(vardef, varname: Union[str, None] = None): + if isinstance(vardef, (AggregatingFunction, ScalarFunction)): self._extra_results[varname if varname else vardef.input_name] = vardef else: raise NotImplementedError diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 4ba70e2f..0bede175 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -942,6 +942,21 @@ def __str__(self): return f"collect({self.input_name})" +@dataclass +class ScalarFunction: + """Base scalar function class.""" + + input_name: Union[str, AggregatingFunction] + + +@dataclass +class Last(ScalarFunction): + """last() function.""" + + def __str__(self) -> str: + return f"last({str(self.input_name)})" + + class NodeSet(BaseSet): """ A class representing as set of nodes matching common query parameters @@ -1165,8 +1180,8 @@ def traverse_relations(self, *relation_names, **aliased_relation_names): def annotate(self, *vars, **aliased_vars): """Annotate node set results with extra variables.""" - def register_extra_var(vardef, varname: str = None): - if isinstance(vardef, AggregatingFunction): + def register_extra_var(vardef, varname: Union[str, None] = None): + if isinstance(vardef, (AggregatingFunction, ScalarFunction)): self._extra_results[varname if varname else vardef.input_name] = vardef else: raise NotImplementedError diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index c4ff24d9..043aef23 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -24,6 +24,7 @@ AsyncQueryBuilder, AsyncTraversal, Collect, + Last, Optional, ) from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined @@ -732,13 +733,14 @@ async def test_subquery(): result = await Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( - supps=Collect("suppliers") + supps=Last(Collect("suppliers")) ), ["supps"], ) result = await result.all() assert len(result) == 1 - assert len(result[0][0][0]) == 2 + assert len(result[0]) == 2 + assert result[0][0] == supplier1 with raises( RuntimeError, diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 09e19bc1..2ca283e9 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -20,7 +20,14 @@ ) from neomodel._async_compat.util import Util from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined -from neomodel.sync_.match import Collect, NodeSet, Optional, QueryBuilder, Traversal +from neomodel.sync_.match import ( + Collect, + Last, + NodeSet, + Optional, + QueryBuilder, + Traversal, +) class SupplierRel(StructuredRel): @@ -720,13 +727,14 @@ def test_subquery(): result = Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( - supps=Collect("suppliers") + supps=Last(Collect("suppliers")) ), ["supps"], ) result = result.all() assert len(result) == 1 - assert len(result[0][0][0]) == 2 + assert len(result[0]) == 2 + assert result[0][0] == supplier1 with raises( RuntimeError,