Skip to content

Commit

Permalink
Added support for last() scalar function
Browse files Browse the repository at this point in the history
  • Loading branch information
tonioo committed Oct 2, 2024
1 parent 56f0ff0 commit c73f523
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 11 deletions.
21 changes: 18 additions & 3 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from neomodel.async_.core import AsyncStructuredNode, adb
from neomodel.async_.relationship import AsyncStructuredRel
Expand Down Expand Up @@ -944,6 +944,21 @@ def __str__(self):
return f"collect({self.input_name})"


@dataclass
class ScalarFunction:
"""Base scalar function class."""

input_name: Union[str, AggregatingFunction]


@dataclass
class Last(ScalarFunction):
"""last() function."""

def __str__(self) -> str:
return f"last({str(self.input_name)})"


class AsyncNodeSet(AsyncBaseSet):
"""
A class representing as set of nodes matching common query parameters
Expand Down Expand Up @@ -1167,8 +1182,8 @@ def traverse_relations(self, *relation_names, **aliased_relation_names):
def annotate(self, *vars, **aliased_vars):
"""Annotate node set results with extra variables."""

def register_extra_var(vardef, varname: str = None):
if isinstance(vardef, AggregatingFunction):
def register_extra_var(vardef, varname: Union[str, None] = None):
if isinstance(vardef, (AggregatingFunction, ScalarFunction)):
self._extra_results[varname if varname else vardef.input_name] = vardef
else:
raise NotImplementedError
Expand Down
21 changes: 18 additions & 3 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
Expand Down Expand Up @@ -942,6 +942,21 @@ def __str__(self):
return f"collect({self.input_name})"


@dataclass
class ScalarFunction:
"""Base scalar function class."""

input_name: Union[str, AggregatingFunction]


@dataclass
class Last(ScalarFunction):
"""last() function."""

def __str__(self) -> str:
return f"last({str(self.input_name)})"


class NodeSet(BaseSet):
"""
A class representing as set of nodes matching common query parameters
Expand Down Expand Up @@ -1165,8 +1180,8 @@ def traverse_relations(self, *relation_names, **aliased_relation_names):
def annotate(self, *vars, **aliased_vars):
"""Annotate node set results with extra variables."""

def register_extra_var(vardef, varname: str = None):
if isinstance(vardef, AggregatingFunction):
def register_extra_var(vardef, varname: Union[str, None] = None):
if isinstance(vardef, (AggregatingFunction, ScalarFunction)):
self._extra_results[varname if varname else vardef.input_name] = vardef
else:
raise NotImplementedError
Expand Down
6 changes: 4 additions & 2 deletions test/async_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AsyncQueryBuilder,
AsyncTraversal,
Collect,
Last,
Optional,
)
from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined
Expand Down Expand Up @@ -732,13 +733,14 @@ async def test_subquery():

result = await Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
supps=Collect("suppliers")
supps=Last(Collect("suppliers"))
),
["supps"],
)
result = await result.all()
assert len(result) == 1
assert len(result[0][0][0]) == 2
assert len(result[0]) == 2
assert result[0][0] == supplier1

with raises(
RuntimeError,
Expand Down
14 changes: 11 additions & 3 deletions test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
)
from neomodel._async_compat.util import Util
from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined
from neomodel.sync_.match import Collect, NodeSet, Optional, QueryBuilder, Traversal
from neomodel.sync_.match import (
Collect,
Last,
NodeSet,
Optional,
QueryBuilder,
Traversal,
)


class SupplierRel(StructuredRel):
Expand Down Expand Up @@ -720,13 +727,14 @@ def test_subquery():

result = Coffee.nodes.subquery(
Coffee.nodes.traverse_relations(suppliers="suppliers").annotate(
supps=Collect("suppliers")
supps=Last(Collect("suppliers"))
),
["supps"],
)
result = result.all()
assert len(result) == 1
assert len(result[0][0][0]) == 2
assert len(result[0]) == 2
assert result[0][0] == supplier1

with raises(
RuntimeError,
Expand Down

0 comments on commit c73f523

Please sign in to comment.