Skip to content

Commit

Permalink
type updates (#636)
Browse files Browse the repository at this point in the history
* type updates

* Add tests
  • Loading branch information
edlouth authored Sep 12, 2023
1 parent d1bcebf commit ea752a5
Show file tree
Hide file tree
Showing 14 changed files with 181 additions and 64 deletions.
15 changes: 10 additions & 5 deletions grai-server/app/api/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Callable, Generic, List, Optional, TypeVar
from typing import Any, Callable, Generic, List, Optional, TypeVar, Union

import strawberry
from django.db.models.query import QuerySet
from strawberry.field import StrawberryField
from strawberry_django.pagination import OffsetPaginationInput
from django.db.models.query import QuerySet
from django.db.models import Model

from .order import apply_order

Expand All @@ -25,7 +27,7 @@ def __init__(
self,
queryset: QuerySet,
filteredQueryset: Optional[QuerySet] = None,
apply_filters: Callable[[QuerySet], QuerySet] = None,
apply_filters: Union[Callable[[QuerySet], QuerySet], None] = None,
order: Optional[StrawberryField] = strawberry.UNSET,
pagination: Optional[OffsetPaginationInput] = strawberry.UNSET,
):
Expand Down Expand Up @@ -66,11 +68,14 @@ def apply_pagination(queryset: QuerySet, pagination: Optional[OffsetPaginationIn
return queryset


S = TypeVar("S")


@strawberry.type
class DataWrapper(Generic[T]):
def __init__(self, data: List[T]):
class DataWrapper(Generic[S]):
def __init__(self, data: List[S]):
self.data = data

@strawberry.django.field
def data(self) -> List[T]:
def data(self) -> List[S]:
return self.data
11 changes: 10 additions & 1 deletion grai-server/app/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@ def get_workspace(
user = get_user(info)

try:
query = {"id": id} if id else {"name": name, "organisation__name": organisationName}
query = None

if id:
query = {"id": id}
elif name and organisationName:
query = {"name": name, "organisation__name": organisationName}

if not query:
raise Exception("Can't find workspace")

workspace = WorkspaceModel.objects.get(**query, memberships__user_id=user.id)
except WorkspaceModel.DoesNotExist:
raise Exception("Can't find workspace")
Expand Down
55 changes: 55 additions & 0 deletions grai-server/app/api/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,61 @@ async def test_workspace_get(test_context):
}


@pytest.mark.django_db
@pytest.mark.asyncio
async def test_workspace_name(test_context):
context, organisation, workspace, user, membership = test_context

query = """
query Workspace($name: String!, $organisationName: String!) {
workspace(name: $name, organisationName: $organisationName) {
id
name
}
}
"""

result = await schema.execute(
query,
variable_values={
"name": workspace.name,
"organisationName": organisation.name,
},
context_value=context,
)

assert result.errors is None
assert result.data["workspace"] == {
"id": str(workspace.id),
"name": workspace.name,
}


@pytest.mark.django_db
@pytest.mark.asyncio
async def test_workspace_none(test_context):
context, organisation, workspace, user, membership = test_context

query = """
query Workspace {
workspace {
id
name
}
}
"""

result = await schema.execute(
query,
context_value=context,
)

assert (
str(result.errors)
== """[GraphQLError("Can't find workspace", locations=[SourceLocation(line=3, column=13)], path=['workspace'])]"""
)


@pytest.mark.django_db
@pytest.mark.asyncio
async def test_workspace_no_workspace(test_context):
Expand Down
22 changes: 14 additions & 8 deletions grai-server/app/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from collections import defaultdict
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

import strawberry
import strawberry_django
Expand All @@ -13,7 +13,6 @@
from grai_graph.graph import BaseSourceSegment
from notifications.models import Alert as AlertModel
from strawberry.scalars import JSON
from strawberry.types import Info
from strawberry_django.filters import FilterLookup
from strawberry_django.pagination import OffsetPaginationInput

Expand Down Expand Up @@ -892,15 +891,22 @@ async def graph(
query.match("(table:Table)")

if filters and filters.source_id:
return graph.get_source_filtered_graph_result(filters.source_id, filters.n)
return graph.get_source_filtered_graph_result(filters.source_id, filters.n or 0)

if filters and filters.table_id:
return graph.get_table_filtered_graph_result(filters.table_id, filters.n)
return graph.get_table_filtered_graph_result(filters.table_id, filters.n or 0)

if filters and filters.edge_id:
return graph.get_edge_filtered_graph_result(filters.edge_id, filters.n)

if filters and filters.min_x is not None and filters.max_x is not strawberry.UNSET:
return graph.get_edge_filtered_graph_result(filters.edge_id, filters.n or 0)

if (
filters
and filters.min_x is not None
and filters.min_x is not strawberry.UNSET
and filters.max_x is not None
and filters.min_y is not None
and filters.max_y is not None
):
graph.filter_by_range(filters.min_x, filters.max_x, filters.min_y, filters.max_y, query)

if filters and filters.filters:
Expand Down Expand Up @@ -962,7 +968,7 @@ def source(self, id: strawberry.ID) -> Source:
# Source Graph
@strawberry.field
def source_graph(self) -> List[SourceGraph]:
def fetch_source_graph(workspace: WorkspaceModel):
def fetch_source_graph(workspace: Workspace):
nodes = defaultdict(list)
for node in (
NodeModel.objects.filter(workspace=workspace, is_active=True)
Expand Down
2 changes: 1 addition & 1 deletion grai-server/app/lineage/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def view(self): # pragma: no cover
)

fields = ("name", view)
readonly_fields = (view,)
readonly_fields = [view]


class SourceAdmin(admin.ModelAdmin):
Expand Down
76 changes: 39 additions & 37 deletions grai-server/app/lineage/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,23 @@ def __str__(self) -> str:
return self.where


WhereType = Union[str, Where]
WhereArrayType = Union[WhereType, List[WhereType]]


def wrapWhere(input: WhereArrayType) -> List[Where]:
if isinstance(input, list):
return [Where(w) if isinstance(w, str) else w for w in input]

return [input if isinstance(input, Where) else Where(input)]


class Match:
def __init__(
self,
match: str,
optional: bool = False,
where: Union[str, Where, List[Where]] = None,
where: WhereArrayType | None = None,
parameters: object = None,
):
self.match = match
Expand All @@ -33,16 +44,16 @@ def __init__(
if isinstance(where, str):
where = Where(where)

self.wheres = wrap(where) if where else []
self.wheres = wrapWhere(where) if where else []
self.parameters = parameters if parameters else {}

def where(self, where: Union[str, Where, List[Where]]) -> "Match":
def where(self, where: WhereArrayType) -> "Match":
if isinstance(where, str):
self.wheres.append(Where(where))

return self

self.wheres.extend(wrap(where))
self.wheres.extend(wrapWhere(where))

return self

Expand All @@ -63,63 +74,54 @@ def get_parameters(self):
return res


Clause = Union[Match, str]
MatchType = Union[Match, str]
MatchTypeArray = Union[MatchType, List[MatchType]]


def wrapMatch(input: MatchTypeArray, optional: bool = False, where: WhereArrayType | None = None) -> List[Match]:
if isinstance(input, list):
return [Match(w, optional=optional, where=where) if isinstance(w, str) else w for w in input]

return [input if isinstance(input, Match) else Match(input, optional=optional, where=where)]


class GraphQuery:
def __init__(self, clause: Union[Clause, List[Clause]] = None, parameters: object = None):
def __init__(
self,
clause: MatchTypeArray | None = None,
parameters: object = None,
):
self.clause = wrap(clause) if clause else []
self.parameters = parameters if parameters else {}
self.withWheres = None
self.withWheres: str | None = None

def match(
self,
match: Union[str, Match, List[Match]],
where: Union[str, Where, List[Where]] = [],
match: MatchTypeArray,
where: WhereArrayType | None = None,
parameters: object = {},
) -> "GraphQuery":
self.parameters = self.parameters | parameters

if isinstance(match, Match):
self.clause.append(match)

return self

if isinstance(match, List):
self.clause.extend(match)

return self

self.clause.append(Match(match, where=where))
matches = wrapMatch(match, where=where)
self.clause.extend(matches)

return self

def optional_match(
self,
match: Union[str, Match, List[Match]],
where: Union[str, Where, List[Where]] = [],
match: MatchTypeArray,
where: WhereArrayType | None = None,
parameters: object = {},
) -> "GraphQuery":
self.parameters = self.parameters | parameters

if isinstance(match, Match):
match.optional = True
self.clause.append(match)

return self

if isinstance(match, List):
for m in match:
m.optional = True
self.clause.extend(match)

return self

self.clause.append(Match(match, optional=True, where=where))
matches = wrapMatch(match, optional=True, where=where)
self.clause.extend(matches)

return self

def where(self, where: Union[str, Where, List[Where]], parameters: object = {}) -> "GraphQuery":
def where(self, where: WhereArrayType, parameters: object = {}) -> "GraphQuery":
if isinstance(where, str):
where = Where(where, parameters)

Expand Down
6 changes: 5 additions & 1 deletion grai-server/app/lineage/graph_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, workspace: Union[Workspace, str]):
db=0,
)

def query(self, query: str, parameters: object = {}, timeout: int = None):
def query(self, query: str, parameters: object = {}, timeout: Optional[int] = None):
try:
return self.manager.graph(f"lineage:{str(self.workspace_id)}").query(query, parameters, timeout=timeout)
except redis.exceptions.ResponseError as e:
Expand Down Expand Up @@ -373,10 +373,14 @@ def filter_by_filters(self, filters, query: GraphQuery) -> GraphQuery:
for filter in filters:
query = filter_by_filter(filter, query)

return query

def filter_by_rows(self, filters, query: GraphQuery) -> GraphQuery:
for filter in filters:
query = filter_by_dict(filter, query)

return query

def get_with_step_graph_result(
self, n: int, parameters: object = {}, where: Optional[str] = None
) -> List["GraphTable"]:
Expand Down
12 changes: 6 additions & 6 deletions grai-server/app/lineage/managers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Sequence
from typing import Any, Iterable, List, Sequence

from django.db import models
from django_multitenant.models import TenantManagerMixin
Expand All @@ -10,13 +10,13 @@
class CacheManager(TenantManagerMixin, models.Manager):
def bulk_create(
self,
objs: Iterable,
batch_size: int = None,
objs: Iterable[Any],
batch_size: int | None = None,
**kwargs,
) -> List:
result = super().bulk_create(objs, **kwargs)

if len(objs) > 0:
if len(list(objs)) > 0:
workspace = objs[0].workspace
cache = GraphCache(workspace)

Expand All @@ -29,7 +29,7 @@ def bulk_create(

def bulk_update(
self,
objs: Iterable,
objs: Iterable[Any],
fields: Sequence[str],
**kwargs,
) -> int:
Expand All @@ -39,7 +39,7 @@ def bulk_update(
**kwargs,
)

if len(objs) > 0:
if len(list(objs)) > 0:
workspace = objs[0].workspace
cache = GraphCache(workspace)

Expand Down
2 changes: 1 addition & 1 deletion grai-server/app/lineage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def set_names(self, *args, **kwargs):
self.display_name = self.name
return self

def cache_model(self, cache: GraphCache = None, delete: bool = False):
def cache_model(self, cache: GraphCache | None = None, delete: bool = False):
if cache:
cache.cache_node(self)
else:
Expand Down
Loading

0 comments on commit ea752a5

Please sign in to comment.