diff --git a/drf_spectacular/contrib/django_oauth_toolkit.py b/drf_spectacular/contrib/django_oauth_toolkit.py index 9b0019bc..93131775 100644 --- a/drf_spectacular/contrib/django_oauth_toolkit.py +++ b/drf_spectacular/contrib/django_oauth_toolkit.py @@ -3,25 +3,36 @@ class DjangoOAuthToolkitScheme(OpenApiAuthenticationExtension): target_class = 'oauth2_provider.contrib.rest_framework.OAuth2Authentication' - name = 'oauth2' + name: str = 'oauth2' def get_security_requirement(self, auto_schema): from oauth2_provider.contrib.rest_framework import ( IsAuthenticatedOrTokenHasScope, TokenHasScope, TokenMatchesOASRequirements, ) + from rest_framework.permissions import AND, OR view = auto_schema.view request = view.request - for permission in auto_schema.view.get_permissions(): - if isinstance(permission, TokenMatchesOASRequirements): - alt_scopes = permission.get_required_alternate_scopes(request, view) + def security_requirement_from_permission(perm) -> list | dict | None: + if isinstance(perm, (OR, AND)): + return ( + security_requirement_from_permission(perm.op1) or security_requirement_from_permission(perm.op2) + ) + if isinstance(perm, TokenMatchesOASRequirements): + alt_scopes = perm.get_required_alternate_scopes(request, view) alt_scopes = alt_scopes.get(auto_schema.method, []) return [{self.name: group} for group in alt_scopes] - if isinstance(permission, IsAuthenticatedOrTokenHasScope): + if isinstance(perm, IsAuthenticatedOrTokenHasScope): return {self.name: TokenHasScope().get_scopes(request, view)} - if isinstance(permission, TokenHasScope): + if isinstance(perm, TokenHasScope): # catch-all for subclasses of TokenHasScope like TokenHasReadWriteScope - return {self.name: permission.get_scopes(request, view)} + return {self.name: perm.get_scopes(request, view)} + return None + + security_requirements = map(security_requirement_from_permission, auto_schema.view.get_permissions()) + for requirement in security_requirements: + if requirement is not None: + return requirement def get_security_definition(self, auto_schema): from oauth2_provider.scopes import get_scopes_backend