diff --git a/grai-server/app/api/pagination.py b/grai-server/app/api/pagination.py index 5f3cf4dd9..da776c71f 100644 --- a/grai-server/app/api/pagination.py +++ b/grai-server/app/api/pagination.py @@ -1,9 +1,11 @@ -from typing import Callable, Generic, List, Optional, TypeVar +from typing import Any, Callable, Generic, List, Optional, TypeVar, Union import strawberry from django.db.models.query import QuerySet from strawberry.field import StrawberryField from strawberry_django.pagination import OffsetPaginationInput +from django.db.models.query import QuerySet +from django.db.models import Model from .order import apply_order @@ -25,7 +27,7 @@ def __init__( self, queryset: QuerySet, filteredQueryset: Optional[QuerySet] = None, - apply_filters: Callable[[QuerySet], QuerySet] = None, + apply_filters: Union[Callable[[QuerySet], QuerySet], None] = None, order: Optional[StrawberryField] = strawberry.UNSET, pagination: Optional[OffsetPaginationInput] = strawberry.UNSET, ): @@ -66,11 +68,14 @@ def apply_pagination(queryset: QuerySet, pagination: Optional[OffsetPaginationIn return queryset +S = TypeVar("S") + + @strawberry.type -class DataWrapper(Generic[T]): - def __init__(self, data: List[T]): +class DataWrapper(Generic[S]): + def __init__(self, data: List[S]): self.data = data @strawberry.django.field - def data(self) -> List[T]: + def data(self) -> List[S]: return self.data diff --git a/grai-server/app/api/queries.py b/grai-server/app/api/queries.py index ed3e12a03..de118a064 100644 --- a/grai-server/app/api/queries.py +++ b/grai-server/app/api/queries.py @@ -26,7 +26,16 @@ def get_workspace( user = get_user(info) try: - query = {"id": id} if id else {"name": name, "organisation__name": organisationName} + query = None + + if id: + query = {"id": id} + elif name and organisationName: + query = {"name": name, "organisation__name": organisationName} + + if not query: + raise Exception("Can't find workspace") + workspace = WorkspaceModel.objects.get(**query, memberships__user_id=user.id) except WorkspaceModel.DoesNotExist: raise Exception("Can't find workspace") diff --git a/grai-server/app/api/tests/test_queries.py b/grai-server/app/api/tests/test_queries.py index 13cc744d2..49e61bda3 100644 --- a/grai-server/app/api/tests/test_queries.py +++ b/grai-server/app/api/tests/test_queries.py @@ -70,6 +70,61 @@ async def test_workspace_get(test_context): } +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_workspace_name(test_context): + context, organisation, workspace, user, membership = test_context + + query = """ + query Workspace($name: String!, $organisationName: String!) { + workspace(name: $name, organisationName: $organisationName) { + id + name + } + } + """ + + result = await schema.execute( + query, + variable_values={ + "name": workspace.name, + "organisationName": organisation.name, + }, + context_value=context, + ) + + assert result.errors is None + assert result.data["workspace"] == { + "id": str(workspace.id), + "name": workspace.name, + } + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_workspace_none(test_context): + context, organisation, workspace, user, membership = test_context + + query = """ + query Workspace { + workspace { + id + name + } + } + """ + + result = await schema.execute( + query, + context_value=context, + ) + + assert ( + str(result.errors) + == """[GraphQLError("Can't find workspace", locations=[SourceLocation(line=3, column=13)], path=['workspace'])]""" + ) + + @pytest.mark.django_db @pytest.mark.asyncio async def test_workspace_no_workspace(test_context): diff --git a/grai-server/app/api/types.py b/grai-server/app/api/types.py index 0c1116463..96bdaaf48 100755 --- a/grai-server/app/api/types.py +++ b/grai-server/app/api/types.py @@ -2,7 +2,7 @@ import time from collections import defaultdict from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import strawberry import strawberry_django @@ -13,7 +13,6 @@ from grai_graph.graph import BaseSourceSegment from notifications.models import Alert as AlertModel from strawberry.scalars import JSON -from strawberry.types import Info from strawberry_django.filters import FilterLookup from strawberry_django.pagination import OffsetPaginationInput @@ -892,15 +891,22 @@ async def graph( query.match("(table:Table)") if filters and filters.source_id: - return graph.get_source_filtered_graph_result(filters.source_id, filters.n) + return graph.get_source_filtered_graph_result(filters.source_id, filters.n or 0) if filters and filters.table_id: - return graph.get_table_filtered_graph_result(filters.table_id, filters.n) + return graph.get_table_filtered_graph_result(filters.table_id, filters.n or 0) if filters and filters.edge_id: - return graph.get_edge_filtered_graph_result(filters.edge_id, filters.n) - - if filters and filters.min_x is not None and filters.max_x is not strawberry.UNSET: + return graph.get_edge_filtered_graph_result(filters.edge_id, filters.n or 0) + + if ( + filters + and filters.min_x is not None + and filters.min_x is not strawberry.UNSET + and filters.max_x is not None + and filters.min_y is not None + and filters.max_y is not None + ): graph.filter_by_range(filters.min_x, filters.max_x, filters.min_y, filters.max_y, query) if filters and filters.filters: @@ -962,7 +968,7 @@ def source(self, id: strawberry.ID) -> Source: # Source Graph @strawberry.field def source_graph(self) -> List[SourceGraph]: - def fetch_source_graph(workspace: WorkspaceModel): + def fetch_source_graph(workspace: Workspace): nodes = defaultdict(list) for node in ( NodeModel.objects.filter(workspace=workspace, is_active=True) diff --git a/grai-server/app/lineage/admin.py b/grai-server/app/lineage/admin.py index a32604d10..4a9c18baa 100755 --- a/grai-server/app/lineage/admin.py +++ b/grai-server/app/lineage/admin.py @@ -171,7 +171,7 @@ def view(self): # pragma: no cover ) fields = ("name", view) - readonly_fields = (view,) + readonly_fields = [view] class SourceAdmin(admin.ModelAdmin): diff --git a/grai-server/app/lineage/graph.py b/grai-server/app/lineage/graph.py index 548e3109c..d0c3705a5 100644 --- a/grai-server/app/lineage/graph.py +++ b/grai-server/app/lineage/graph.py @@ -19,12 +19,23 @@ def __str__(self) -> str: return self.where +WhereType = Union[str, Where] +WhereArrayType = Union[WhereType, List[WhereType]] + + +def wrapWhere(input: WhereArrayType) -> List[Where]: + if isinstance(input, list): + return [Where(w) if isinstance(w, str) else w for w in input] + + return [input if isinstance(input, Where) else Where(input)] + + class Match: def __init__( self, match: str, optional: bool = False, - where: Union[str, Where, List[Where]] = None, + where: WhereArrayType | None = None, parameters: object = None, ): self.match = match @@ -33,16 +44,16 @@ def __init__( if isinstance(where, str): where = Where(where) - self.wheres = wrap(where) if where else [] + self.wheres = wrapWhere(where) if where else [] self.parameters = parameters if parameters else {} - def where(self, where: Union[str, Where, List[Where]]) -> "Match": + def where(self, where: WhereArrayType) -> "Match": if isinstance(where, str): self.wheres.append(Where(where)) return self - self.wheres.extend(wrap(where)) + self.wheres.extend(wrapWhere(where)) return self @@ -63,63 +74,54 @@ def get_parameters(self): return res -Clause = Union[Match, str] +MatchType = Union[Match, str] +MatchTypeArray = Union[MatchType, List[MatchType]] + + +def wrapMatch(input: MatchTypeArray, optional: bool = False, where: WhereArrayType | None = None) -> List[Match]: + if isinstance(input, list): + return [Match(w, optional=optional, where=where) if isinstance(w, str) else w for w in input] + + return [input if isinstance(input, Match) else Match(input, optional=optional, where=where)] class GraphQuery: - def __init__(self, clause: Union[Clause, List[Clause]] = None, parameters: object = None): + def __init__( + self, + clause: MatchTypeArray | None = None, + parameters: object = None, + ): self.clause = wrap(clause) if clause else [] self.parameters = parameters if parameters else {} - self.withWheres = None + self.withWheres: str | None = None def match( self, - match: Union[str, Match, List[Match]], - where: Union[str, Where, List[Where]] = [], + match: MatchTypeArray, + where: WhereArrayType | None = None, parameters: object = {}, ) -> "GraphQuery": self.parameters = self.parameters | parameters - if isinstance(match, Match): - self.clause.append(match) - - return self - - if isinstance(match, List): - self.clause.extend(match) - - return self - - self.clause.append(Match(match, where=where)) + matches = wrapMatch(match, where=where) + self.clause.extend(matches) return self def optional_match( self, - match: Union[str, Match, List[Match]], - where: Union[str, Where, List[Where]] = [], + match: MatchTypeArray, + where: WhereArrayType | None = None, parameters: object = {}, ) -> "GraphQuery": self.parameters = self.parameters | parameters - if isinstance(match, Match): - match.optional = True - self.clause.append(match) - - return self - - if isinstance(match, List): - for m in match: - m.optional = True - self.clause.extend(match) - - return self - - self.clause.append(Match(match, optional=True, where=where)) + matches = wrapMatch(match, optional=True, where=where) + self.clause.extend(matches) return self - def where(self, where: Union[str, Where, List[Where]], parameters: object = {}) -> "GraphQuery": + def where(self, where: WhereArrayType, parameters: object = {}) -> "GraphQuery": if isinstance(where, str): where = Where(where, parameters) diff --git a/grai-server/app/lineage/graph_cache.py b/grai-server/app/lineage/graph_cache.py index 778a2e622..1bf0fb5ad 100644 --- a/grai-server/app/lineage/graph_cache.py +++ b/grai-server/app/lineage/graph_cache.py @@ -29,7 +29,7 @@ def __init__(self, workspace: Union[Workspace, str]): db=0, ) - def query(self, query: str, parameters: object = {}, timeout: int = None): + def query(self, query: str, parameters: object = {}, timeout: Optional[int] = None): try: return self.manager.graph(f"lineage:{str(self.workspace_id)}").query(query, parameters, timeout=timeout) except redis.exceptions.ResponseError as e: @@ -373,10 +373,14 @@ def filter_by_filters(self, filters, query: GraphQuery) -> GraphQuery: for filter in filters: query = filter_by_filter(filter, query) + return query + def filter_by_rows(self, filters, query: GraphQuery) -> GraphQuery: for filter in filters: query = filter_by_dict(filter, query) + return query + def get_with_step_graph_result( self, n: int, parameters: object = {}, where: Optional[str] = None ) -> List["GraphTable"]: diff --git a/grai-server/app/lineage/managers.py b/grai-server/app/lineage/managers.py index 5b1488df1..3bf52d71f 100644 --- a/grai-server/app/lineage/managers.py +++ b/grai-server/app/lineage/managers.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Sequence +from typing import Any, Iterable, List, Sequence from django.db import models from django_multitenant.models import TenantManagerMixin @@ -10,13 +10,13 @@ class CacheManager(TenantManagerMixin, models.Manager): def bulk_create( self, - objs: Iterable, - batch_size: int = None, + objs: Iterable[Any], + batch_size: int | None = None, **kwargs, ) -> List: result = super().bulk_create(objs, **kwargs) - if len(objs) > 0: + if len(list(objs)) > 0: workspace = objs[0].workspace cache = GraphCache(workspace) @@ -29,7 +29,7 @@ def bulk_create( def bulk_update( self, - objs: Iterable, + objs: Iterable[Any], fields: Sequence[str], **kwargs, ) -> int: @@ -39,7 +39,7 @@ def bulk_update( **kwargs, ) - if len(objs) > 0: + if len(list(objs)) > 0: workspace = objs[0].workspace cache = GraphCache(workspace) diff --git a/grai-server/app/lineage/models.py b/grai-server/app/lineage/models.py index 22ee9f177..52b5a27d2 100755 --- a/grai-server/app/lineage/models.py +++ b/grai-server/app/lineage/models.py @@ -69,7 +69,7 @@ def set_names(self, *args, **kwargs): self.display_name = self.name return self - def cache_model(self, cache: GraphCache = None, delete: bool = False): + def cache_model(self, cache: GraphCache | None = None, delete: bool = False): if cache: cache.cache_node(self) else: diff --git a/grai-server/app/lineage/serializers.py b/grai-server/app/lineage/serializers.py index 20b2a2622..95b006c51 100755 --- a/grai-server/app/lineage/serializers.py +++ b/grai-server/app/lineage/serializers.py @@ -101,7 +101,10 @@ def to_internal_value(self, data): match data: case { "source": {"name": source_name, "namespace": source_namespace}, - "destination": {"name": destination_name, "namespace": destination_namespace}, + "destination": { + "name": destination_name, + "namespace": destination_namespace, + }, }: q_filter = Q(name=source_name) & Q(namespace=source_namespace) q_filter |= Q(name=destination_name) & Q(namespace=destination_namespace) @@ -143,7 +146,7 @@ class Meta: read_only_fields = ("created_at", "updated_at") -class SourceChildMixin: +class SourceChildMixin(serializers.ModelSerializer): @cached_property def source_model(self) -> Source: return Source.objects.get(pk=self.context["view"].kwargs["source_pk"]) diff --git a/grai-server/app/lineage/tests/test_graph.py b/grai-server/app/lineage/tests/test_graph.py index 4b41cc78c..a7cd148c8 100644 --- a/grai-server/app/lineage/tests/test_graph.py +++ b/grai-server/app/lineage/tests/test_graph.py @@ -218,7 +218,7 @@ def test_where_no_match(self): def test_add(self): query = GraphQuery(clause="(a)").add("abc") - assert query.clause[1] == "abc" + assert str(query.clause[1]) == "abc" assert query.parameters == {} assert str(query) == "(a) abc" diff --git a/grai-server/app/mypy.ini b/grai-server/app/mypy.ini new file mode 100644 index 000000000..bf22d44dd --- /dev/null +++ b/grai-server/app/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +exclude = tests|scripts +ignore_missing_imports = True diff --git a/grai-server/app/poetry.lock b/grai-server/app/poetry.lock index eb202a484..fdef133ce 100644 --- a/grai-server/app/poetry.lock +++ b/grai-server/app/poetry.lock @@ -4414,6 +4414,20 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2 doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)"] test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<5.4.0)", "pytest-cov (>=2.10.0,<3.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<2.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +[[package]] +name = "types-pyopenssl" +version = "23.2.0.2" +description = "Typing stubs for pyOpenSSL" +optional = false +python-versions = "*" +files = [ + {file = "types-pyOpenSSL-23.2.0.2.tar.gz", hash = "sha256:6a010dac9ecd42b582d7dd2cc3e9e40486b79b3b64bb2fffba1474ff96af906d"}, + {file = "types_pyOpenSSL-23.2.0.2-py3-none-any.whl", hash = "sha256:19536aa3debfbe25a918cf0d898e9f5fbbe6f3594a429da7914bf331deb1b342"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" + [[package]] name = "types-pytz" version = "2023.3.0.1" @@ -4436,6 +4450,21 @@ files = [ {file = "types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"}, ] +[[package]] +name = "types-redis" +version = "4.6.0.5" +description = "Typing stubs for redis" +optional = false +python-versions = "*" +files = [ + {file = "types-redis-4.6.0.5.tar.gz", hash = "sha256:5f179d10bd3ca995a8134aafcddfc3e12d52b208437c4529ef27e68acb301f38"}, + {file = "types_redis-4.6.0.5-py3-none-any.whl", hash = "sha256:4f662060247a2363c7a8f0b7e52915d68960870ff16a749a891eabcf87ed0be4"}, +] + +[package.dependencies] +cryptography = ">=35.0.0" +types-pyOpenSSL = "*" + [[package]] name = "types-requests" version = "2.31.0.2" @@ -4552,4 +4581,4 @@ brotli = ["Brotli"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "6feaca99bd70d809e0fdbd02e334cf1ddb95da7ced6ead3bbfea8ce14ce6906b" +content-hash = "b4c5ff97d888feb835af9236df1450fb0a0025597355b5574f3a019fd8a5e52a" diff --git a/grai-server/app/pyproject.toml b/grai-server/app/pyproject.toml index b1b2bc2f2..790061816 100644 --- a/grai-server/app/pyproject.toml +++ b/grai-server/app/pyproject.toml @@ -65,6 +65,7 @@ grandalf = "^0.8" drf-nested-routers = "^0.93.4" django-redis = "^5.3.0" retakesearch = "^0.1.32" +types-redis = "^4.6.0.5" [tool.poetry.group.dev.dependencies] isort = "^5.10.1"