diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 0e3537cf0..c81ba9cc7 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -5,9 +5,14 @@ import jwt from django.utils.translation import gettext_lazy as _ -from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms - -from .exceptions import TokenBackendError +from jwt import ( + ExpiredSignatureError, + InvalidAlgorithmError, + InvalidTokenError, + algorithms, +) + +from .exceptions import TokenBackendError, TokenBackendExpiredToken from .tokens import Token from .utils import format_lazy @@ -101,7 +106,7 @@ def get_verifying_key(self, token: Token) -> Optional[str]: try: return self.jwks_client.get_signing_key_from_jwt(token).key except PyJWKClientError as ex: - raise TokenBackendError(_("Token is invalid or expired")) from ex + raise TokenBackendError(_("Token is invalid")) from ex return self.verifying_key @@ -150,5 +155,7 @@ def decode(self, token: Token, verify: bool = True) -> Dict[str, Any]: ) except InvalidAlgorithmError as ex: raise TokenBackendError(_("Invalid algorithm specified")) from ex + except ExpiredSignatureError as ex: + raise TokenBackendExpiredToken(_("Token is expired")) from ex except InvalidTokenError as ex: - raise TokenBackendError(_("Token is invalid or expired")) from ex + raise TokenBackendError(_("Token is invalid")) from ex diff --git a/rest_framework_simplejwt/exceptions.py b/rest_framework_simplejwt/exceptions.py index 882635215..5db017bbe 100644 --- a/rest_framework_simplejwt/exceptions.py +++ b/rest_framework_simplejwt/exceptions.py @@ -8,10 +8,18 @@ class TokenError(Exception): pass +class ExpiredTokenError(TokenError): + pass + + class TokenBackendError(Exception): pass +class TokenBackendExpiredToken(TokenBackendError): + pass + + class DetailDictMixin: default_detail: str default_code: str diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index 9e9c3b9df..01226af49 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -7,7 +7,12 @@ from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ -from .exceptions import TokenBackendError, TokenError +from .exceptions import ( + ExpiredTokenError, + TokenBackendError, + TokenBackendExpiredToken, + TokenError, +) from .models import TokenUser from .settings import api_settings from .token_blacklist.models import BlacklistedToken, OutstandingToken @@ -56,8 +61,10 @@ def __init__(self, token: Optional["Token"] = None, verify: bool = True) -> None # Decode token try: self.payload = token_backend.decode(token, verify=verify) + except TokenBackendExpiredToken: + raise ExpiredTokenError(_("Token is expired")) except TokenBackendError: - raise TokenError(_("Token is invalid or expired")) + raise TokenError(_("Token is invalid")) if verify: self.verify() diff --git a/tests/test_backends.py b/tests/test_backends.py index fd19183e0..4954588db 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -14,7 +14,10 @@ from jwt import algorithms from rest_framework_simplejwt.backends import JWK_CLIENT_AVAILABLE, TokenBackend -from rest_framework_simplejwt.exceptions import TokenBackendError +from rest_framework_simplejwt.exceptions import ( + TokenBackendError, + TokenBackendExpiredToken, +) from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc from tests.keys import ( ES256_PRIVATE_KEY, @@ -191,7 +194,7 @@ def test_decode_with_expiry(self): self.payload, backend.signing_key, algorithm=backend.algorithm ) - with self.assertRaises(TokenBackendError): + with self.assertRaises(TokenBackendExpiredToken): backend.decode(expired_token) def test_decode_with_invalid_sig(self): @@ -346,9 +349,7 @@ def test_decode_jwk_missing_key_raises_tokenbackenderror(self): "RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL ) - with self.assertRaisesRegex( - TokenBackendError, "Token is invalid or expired" - ): + with self.assertRaisesRegex(TokenBackendError, "Token is invalid"): jwk_token_backend.decode(token) def test_decode_when_algorithm_not_available(self): diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 0fe03a9f1..d7b287856 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -206,7 +206,7 @@ def test_it_should_not_validate_if_token_invalid(self): with self.assertRaises(TokenError) as e: s.is_valid() - self.assertIn("invalid or expired", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) def test_it_should_raise_token_error_if_token_has_no_refresh_exp_claim(self): token = SlidingToken() @@ -337,7 +337,7 @@ def test_it_should_raise_token_error_if_token_invalid(self): with self.assertRaises(TokenError) as e: s.is_valid() - self.assertIn("invalid or expired", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) def test_it_should_raise_token_error_if_token_has_wrong_type(self): token = RefreshToken() @@ -503,7 +503,7 @@ def test_it_should_raise_token_error_if_token_invalid(self): with self.assertRaises(TokenError) as e: s.is_valid() - self.assertIn("invalid or expired", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) def test_it_should_not_raise_token_error_if_token_has_wrong_type(self): token = RefreshToken() @@ -548,7 +548,7 @@ def test_it_should_raise_token_error_if_token_invalid(self): with self.assertRaises(TokenError) as e: s.is_valid() - self.assertIn("invalid or expired", e.exception.args[0]) + self.assertIn("expired", e.exception.args[0]) def test_it_should_raise_token_error_if_token_has_wrong_type(self): token = RefreshToken() diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 47702e33a..5605bae08 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -6,7 +6,11 @@ from django.test import TestCase from jose import jwt -from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError +from rest_framework_simplejwt.exceptions import ( + ExpiredTokenError, + TokenBackendError, + TokenError, +) from rest_framework_simplejwt.settings import api_settings from rest_framework_simplejwt.state import token_backend from rest_framework_simplejwt.tokens import ( @@ -157,7 +161,7 @@ def test_init_expired_token_given(self): t = MyToken() t.set_exp(lifetime=-timedelta(seconds=1)) - with self.assertRaises(TokenError): + with self.assertRaises(ExpiredTokenError): MyToken(str(t)) def test_init_no_type_token_given(self):