Skip to content

Commit

Permalink
ProductsQueryset manager execute duplicated channel queries (saleor#1…
Browse files Browse the repository at this point in the history
…5858)

* ProductsQueryset manager execute duplicated channel queries (saleor#15787)

* Optmize ProductsQueryset.published function

* Optimize ProductsQueryset.not_published function

* Optmize ProductsQueryset.published_with_variants function

* Extend visible_to_user_test

* Extend test for fetching products via federated query

* Extend test for fetching variants via federated query

* Optimize ProductsQueryset.visible_to_user function

* Optimize ProductsQueryset.annotate_publication_info function

* Optimize ProductsQueryset.annotate_is_published function

* Optimize ProductsQueryset.annotate_published_at function

* Optimize ProductsQueryset.annotate_visible_in_listings function

* Optimize ProductVariantQueryset.available_in_channel function

* Rename aargument of visible_to_user

* Add missing using to querysets

* Fix docstring
  • Loading branch information
fowczarek authored Apr 30, 2024
1 parent 899bfc2 commit fa9ea3a
Show file tree
Hide file tree
Showing 13 changed files with 778 additions and 232 deletions.
9 changes: 7 additions & 2 deletions saleor/graphql/attribute/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ...attribute import AttributeInputType
from ...attribute.models import Attribute, AttributeValue
from ...channel.models import Channel
from ...permission.utils import has_one_of_permissions
from ...product import models
from ...product.models import ALL_PRODUCTS_PERMISSIONS
Expand Down Expand Up @@ -42,8 +43,12 @@ def filter_attributes_by_product_types(qs, field, value, requestor, channel_slug
if not value:
return qs

channel = None
if channel_slug is not None:
channel = Channel.objects.using(qs.db).filter(slug=str(channel_slug)).first()
limited_channel_access = False if channel_slug is None else True
product_qs = models.Product.objects.using(qs.db).visible_to_user(
requestor, channel_slug
requestor, channel, limited_channel_access
)

if field == "in_category":
Expand All @@ -57,7 +62,7 @@ def filter_attributes_by_product_types(qs, field, value, requestor, channel_slug
product_qs = product_qs.filter(category__in=tree)

if not has_one_of_permissions(requestor, ALL_PRODUCTS_PERMISSIONS):
product_qs = product_qs.annotate_visible_in_listings(channel_slug).exclude(
product_qs = product_qs.annotate_visible_in_listings(channel).exclude(
visible_in_listings=False
)

Expand Down
10 changes: 5 additions & 5 deletions saleor/graphql/order/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def validate_product_is_published(
unpublished_product = (
Product.objects.using(database_connection_name)
.filter(variants__id__in=variant_ids)
.not_published(channel.slug)
.not_published(channel)
)
if unpublished_product.exists():
errors["lines"].append(
Expand Down Expand Up @@ -238,7 +238,7 @@ def validate_product_is_published_in_channel(
unpublished_product = list(
Product.objects.using(database_connection_name)
.filter(variants__id__in=variant_ids)
.not_published(channel.slug)
.not_published(channel)
)
if unpublished_product:
unpublished_variants = (
Expand Down Expand Up @@ -398,9 +398,9 @@ def prepare_insufficient_stock_order_validation_errors(exc):
"Insufficient product stock.",
code=OrderErrorCode.INSUFFICIENT_STOCK.value,
params={
"order_lines": [order_line_global_id]
if order_line_global_id
else [],
"order_lines": (
[order_line_global_id] if order_line_global_id else []
),
"warehouse": warehouse_global_id,
},
)
Expand Down
43 changes: 27 additions & 16 deletions saleor/graphql/product/resolvers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from django.db.models import Exists, OuterRef, Sum

from ...channel.models import Channel
Expand Down Expand Up @@ -66,11 +68,17 @@ def resolve_digital_contents(info: ResolveInfo):


def resolve_product(
info: ResolveInfo, id, slug, external_reference, channel_slug, requestor
info: ResolveInfo,
id,
slug,
external_reference,
channel: Optional[Channel],
limited_channel_access: bool,
requestor,
):
database_connection_name = get_database_connection_name(info.context)
qs = models.Product.objects.using(database_connection_name).visible_to_user(
requestor, channel_slug=channel_slug
requestor, channel, limited_channel_access
)
if id:
_type, id = from_global_id_or_error(id, "Product")
Expand All @@ -83,18 +91,17 @@ def resolve_product(

@traced_resolver
def resolve_products(
info: ResolveInfo, requestor, channel_slug=None
info: ResolveInfo,
requestor,
channel: Optional[Channel],
limited_channel_access: bool,
) -> ChannelQsContext:
connection_name = get_database_connection_name(info.context)
qs = models.Product.objects.using(connection_name).visible_to_user(
requestor, channel_slug
requestor, channel, limited_channel_access
)
if not has_one_of_permissions(requestor, ALL_PRODUCTS_PERMISSIONS):
if channel := (
Channel.objects.using(connection_name)
.filter(slug=str(channel_slug))
.first()
):
if channel:
product_channel_listings = (
models.ProductChannelListing.objects.using(connection_name)
.filter(channel_id=channel.id, visible_in_listings=True)
Expand All @@ -105,6 +112,7 @@ def resolve_products(
)
else:
qs = models.Product.objects.none()
channel_slug = channel.slug if channel else None
return ChannelQsContext(qs=qs, channel_slug=channel_slug)


Expand All @@ -129,21 +137,22 @@ def resolve_variant(
sku,
external_reference,
*,
channel_slug,
channel: Optional[Channel],
limited_channel_access: bool,
requestor,
requestor_has_access_to_all,
):
connection_name = get_database_connection_name(info.context)
visible_products = (
models.Product.objects.using(connection_name)
.visible_to_user(requestor, channel_slug)
.visible_to_user(requestor, channel, limited_channel_access)
.values_list("pk", flat=True)
)
qs = models.ProductVariant.objects.using(connection_name).filter(
product__id__in=visible_products
)
if not requestor_has_access_to_all:
qs = qs.available_in_channel(channel_slug)
qs = qs.available_in_channel(channel)
if id:
_, id = from_global_id_or_error(id, "ProductVariant")
return qs.filter(pk=id).first()
Expand All @@ -159,24 +168,26 @@ def resolve_product_variants(
requestor_has_access_to_all,
requestor,
ids=None,
channel_slug=None,
channel: Optional[Channel] = None,
limited_channel_access: bool = False,
) -> ChannelQsContext:
connection_name = get_database_connection_name(info.context)
visible_products = models.Product.objects.using(connection_name).visible_to_user(
requestor, channel_slug
requestor, channel, limited_channel_access
)
qs = models.ProductVariant.objects.using(connection_name).filter(
product__id__in=visible_products
)

channel_slug = channel.slug if channel else None
if not requestor_has_access_to_all:
visible_products = visible_products.annotate_visible_in_listings(
channel_slug
channel
).exclude(visible_in_listings=False)
qs = (
qs.using(connection_name)
.filter(product__in=visible_products)
.available_in_channel(channel_slug)
.available_in_channel(channel)
)
if ids:
db_ids = [
Expand Down
137 changes: 95 additions & 42 deletions saleor/graphql/product/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ...product.models import ALL_PRODUCTS_PERMISSIONS
from ...product.search import search_products
from ..channel import ChannelContext, ChannelQsContext
from ..channel.dataloaders import ChannelBySlugLoader
from ..channel.utils import get_default_channel_slug_or_graphql_error
from ..core import ResolveInfo
from ..core.connection import create_connection_slice, filter_connection_queryset
Expand Down Expand Up @@ -436,21 +437,35 @@ def resolve_product(
requestor, ALL_PRODUCTS_PERMISSIONS
)

limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)

product = resolve_product(
info,
id=id,
slug=slug,
external_reference=external_reference,
channel_slug=channel,
requestor=requestor,
)
def _resolve_product(channel_obj):
product = resolve_product(
info,
id=id,
slug=slug,
external_reference=external_reference,
channel=channel_obj,
limited_channel_access=limited_channel_access,
requestor=requestor,
)

return ChannelContext(node=product, channel_slug=channel) if product else None
return (
ChannelContext(node=product, channel_slug=channel) if product else None
)

if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_product)
)
else:
return _resolve_product(None)

@staticmethod
@traced_resolver
Expand All @@ -462,20 +477,32 @@ def resolve_products(_root, info: ResolveInfo, *, channel=None, **kwargs):
has_required_permissions = has_one_of_permissions(
requestor, ALL_PRODUCTS_PERMISSIONS
)
limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)
qs = resolve_products(info, requestor, channel_slug=channel)
if search:
qs = ChannelQsContext(
qs=search_products(qs.qs, search), channel_slug=channel

def _resolve_products(channel_obj):
qs = resolve_products(info, requestor, channel_obj, limited_channel_access)
if search:
qs = ChannelQsContext(
qs=search_products(qs.qs, search), channel_slug=channel
)
kwargs["channel"] = channel
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
kwargs["channel"] = channel
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
return create_connection_slice(qs, info, kwargs, ProductCountableConnection)
return create_connection_slice(qs, info, kwargs, ProductCountableConnection)

if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_products)
)
else:
return _resolve_products(None)

@staticmethod
def resolve_product_type(_root, info: ResolveInfo, *, id):
Expand Down Expand Up @@ -509,22 +536,35 @@ def resolve_product_variant(
requestor, ALL_PRODUCTS_PERMISSIONS
)

limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)

variant = resolve_variant(
info,
id,
sku,
external_reference,
channel_slug=channel,
requestor=requestor,
requestor_has_access_to_all=has_required_permissions,
)
def _resolve_product_variant(channel_obj):
variant = resolve_variant(
info,
id,
sku,
external_reference,
channel=channel_obj,
limited_channel_access=limited_channel_access,
requestor=requestor,
requestor_has_access_to_all=has_required_permissions,
)
return (
ChannelContext(node=variant, channel_slug=channel) if variant else None
)

return ChannelContext(node=variant, channel_slug=channel) if variant else None
if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_product_variant)
)
else:
return _resolve_product_variant(None)

@staticmethod
def resolve_product_variants(
Expand All @@ -534,24 +574,37 @@ def resolve_product_variants(
has_required_permissions = has_one_of_permissions(
requestor, ALL_PRODUCTS_PERMISSIONS
)
limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)
qs = resolve_product_variants(
info,
ids=ids,
channel_slug=channel,
requestor_has_access_to_all=has_required_permissions,
requestor=requestor,
)
kwargs["channel"] = qs.channel_slug
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
return create_connection_slice(
qs, info, kwargs, ProductVariantCountableConnection
)

def _resolve_product_variants(channel_obj):
qs = resolve_product_variants(
info,
ids=ids,
channel=channel_obj,
limited_channel_access=limited_channel_access,
requestor_has_access_to_all=has_required_permissions,
requestor=requestor,
)
kwargs["channel"] = qs.channel_slug
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
return create_connection_slice(
qs, info, kwargs, ProductVariantCountableConnection
)

if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_product_variants)
)
else:
return _resolve_product_variants(None)

@staticmethod
@traced_resolver
Expand Down
4 changes: 2 additions & 2 deletions saleor/graphql/product/tests/benchmark/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ def test_products_for_federation_query_count(
],
}

with django_assert_num_queries(5):
with django_assert_num_queries(3):
response = api_client.post_graphql(query, variables)
content = get_graphql_content(response)
assert len(content["data"]["_entities"]) == 1
Expand All @@ -765,7 +765,7 @@ def test_products_for_federation_query_count(
],
}

with django_assert_num_queries(5):
with django_assert_num_queries(3):
response = api_client.post_graphql(query, variables)
content = get_graphql_content(response)
assert len(content["data"]["_entities"]) == 2
Expand Down
Loading

0 comments on commit fa9ea3a

Please sign in to comment.