Skip to content

Commit

Permalink
Add specific "token expired" exceptions (#830)
Browse files Browse the repository at this point in the history
* Add a specific backend exception for expired tokens

To later allow specific handling for this case in the layers above.

* Add a separate TokenError subclass for expired tokens

To allow the caller to handle expired tokens separately from invalid ones
without resorting to string matching.
  • Loading branch information
vainu-arto authored Jan 11, 2025
1 parent 9a66629 commit f602132
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 18 deletions.
17 changes: 12 additions & 5 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@

import jwt
from django.utils.translation import gettext_lazy as _
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms

from .exceptions import TokenBackendError
from jwt import (
ExpiredSignatureError,
InvalidAlgorithmError,
InvalidTokenError,
algorithms,
)

from .exceptions import TokenBackendError, TokenBackendExpiredToken
from .tokens import Token
from .utils import format_lazy

Expand Down Expand Up @@ -101,7 +106,7 @@ def get_verifying_key(self, token: Token) -> Optional[str]:
try:
return self.jwks_client.get_signing_key_from_jwt(token).key
except PyJWKClientError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex
raise TokenBackendError(_("Token is invalid")) from ex

return self.verifying_key

Expand Down Expand Up @@ -150,5 +155,7 @@ def decode(self, token: Token, verify: bool = True) -> Dict[str, Any]:
)
except InvalidAlgorithmError as ex:
raise TokenBackendError(_("Invalid algorithm specified")) from ex
except ExpiredSignatureError as ex:
raise TokenBackendExpiredToken(_("Token is expired")) from ex
except InvalidTokenError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex
raise TokenBackendError(_("Token is invalid")) from ex
8 changes: 8 additions & 0 deletions rest_framework_simplejwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,18 @@ class TokenError(Exception):
pass


class ExpiredTokenError(TokenError):
pass


class TokenBackendError(Exception):
pass


class TokenBackendExpiredToken(TokenBackendError):
pass


class DetailDictMixin:
default_detail: str
default_code: str
Expand Down
11 changes: 9 additions & 2 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from django.utils.module_loading import import_string
from django.utils.translation import gettext_lazy as _

from .exceptions import TokenBackendError, TokenError
from .exceptions import (
ExpiredTokenError,
TokenBackendError,
TokenBackendExpiredToken,
TokenError,
)
from .models import TokenUser
from .settings import api_settings
from .token_blacklist.models import BlacklistedToken, OutstandingToken
Expand Down Expand Up @@ -56,8 +61,10 @@ def __init__(self, token: Optional["Token"] = None, verify: bool = True) -> None
# Decode token
try:
self.payload = token_backend.decode(token, verify=verify)
except TokenBackendExpiredToken:
raise ExpiredTokenError(_("Token is expired"))
except TokenBackendError:
raise TokenError(_("Token is invalid or expired"))
raise TokenError(_("Token is invalid"))

if verify:
self.verify()
Expand Down
11 changes: 6 additions & 5 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from jwt import algorithms

from rest_framework_simplejwt.backends import JWK_CLIENT_AVAILABLE, TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
from rest_framework_simplejwt.exceptions import (
TokenBackendError,
TokenBackendExpiredToken,
)
from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc
from tests.keys import (
ES256_PRIVATE_KEY,
Expand Down Expand Up @@ -191,7 +194,7 @@ def test_decode_with_expiry(self):
self.payload, backend.signing_key, algorithm=backend.algorithm
)

with self.assertRaises(TokenBackendError):
with self.assertRaises(TokenBackendExpiredToken):
backend.decode(expired_token)

def test_decode_with_invalid_sig(self):
Expand Down Expand Up @@ -346,9 +349,7 @@ def test_decode_jwk_missing_key_raises_tokenbackenderror(self):
"RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
)

with self.assertRaisesRegex(
TokenBackendError, "Token is invalid or expired"
):
with self.assertRaisesRegex(TokenBackendError, "Token is invalid"):
jwk_token_backend.decode(token)

def test_decode_when_algorithm_not_available(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_it_should_not_validate_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_no_refresh_exp_claim(self):
token = SlidingToken()
Expand Down Expand Up @@ -337,7 +337,7 @@ def test_it_should_raise_token_error_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
Expand Down Expand Up @@ -503,7 +503,7 @@ def test_it_should_raise_token_error_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_not_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
Expand Down Expand Up @@ -548,7 +548,7 @@ def test_it_should_raise_token_error_if_token_invalid(self):
with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("invalid or expired", e.exception.args[0])
self.assertIn("expired", e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
Expand Down
8 changes: 6 additions & 2 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from django.test import TestCase
from jose import jwt

from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError
from rest_framework_simplejwt.exceptions import (
ExpiredTokenError,
TokenBackendError,
TokenError,
)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import token_backend
from rest_framework_simplejwt.tokens import (
Expand Down Expand Up @@ -157,7 +161,7 @@ def test_init_expired_token_given(self):
t = MyToken()
t.set_exp(lifetime=-timedelta(seconds=1))

with self.assertRaises(TokenError):
with self.assertRaises(ExpiredTokenError):
MyToken(str(t))

def test_init_no_type_token_given(self):
Expand Down

0 comments on commit f602132

Please sign in to comment.