diff --git a/README.md b/README.md index c238a7d..5023fce 100644 --- a/README.md +++ b/README.md @@ -94,10 +94,25 @@ The user name field in the JWT token payload: app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key='secret', username_field='user')) ``` +*audience* + +The audience field in the JWT token is validated: +```python +# Example: changes the username field to "user" +app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key='secret', username_field='user', audience='test_aud')) +``` + +*options* + +The options set to ignore audience verification: +```python +# Example: changes the username field to "user" +app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key='secret', username_field='user', options={"verify_aud": False})) +``` + ## Todo * Support JWT token standard payload -* Set JWT options (time expiration for example) ## Developing diff --git a/starlette_jwt/middleware.py b/starlette_jwt/middleware.py index d8e590a..1cf397c 100644 --- a/starlette_jwt/middleware.py +++ b/starlette_jwt/middleware.py @@ -61,11 +61,14 @@ async def authenticate(self, request): class JWTWebSocketAuthenticationBackend(AuthenticationBackend): def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt', - username_field: str = 'username'): + username_field: str = 'username', audience = None, options = {}): self.secret_key = secret_key self.algorithm = algorithm self.query_param_name = query_param_name self.username_field = username_field + self.audience = audience + self.options = options + async def authenticate(self, request): if self.query_param_name not in request.query_params: @@ -74,7 +77,8 @@ async def authenticate(self, request): token = request.query_params[self.query_param_name] try: - payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm) + payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm, audience=self.audience, + options=self.options) except jwt.InvalidTokenError as e: raise AuthenticationError(str(e)) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 34b2be2..db8e388 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -78,6 +78,48 @@ def test_websocket_valid_authentication(): assert websocket.scope['user'].is_authenticated +def test_websocket_valid_authentication_and_audience(): + secret_key = 'example' + app = create_app() + app.add_middleware(AuthenticationMiddleware, backend=JWTWebSocketAuthenticationBackend(secret_key=secret_key, + audience="test_aud")) + client = TestClient(app) + token = jwt.encode(dict(username="user", aud="test_aud"), secret_key, algorithm="HS256").decode() + with client.websocket_connect(f"/ws-auth?jwt={token}") as websocket: + data = websocket.receive_text() + assert data == 'Authentication valid' + assert websocket.scope['user'].is_authenticated + + +def test_websocket_valid_authentication_and_audience_list(): + secret_key = 'example' + app = create_app() + app.add_middleware(AuthenticationMiddleware, backend=JWTWebSocketAuthenticationBackend(secret_key=secret_key, + audience=["test_aud"])) + client = TestClient(app) + token = jwt.encode(dict(username="user", aud="test_aud"), secret_key, algorithm="HS256").decode() + with client.websocket_connect(f"/ws-auth?jwt={token}") as websocket: + data = websocket.receive_text() + assert data == 'Authentication valid' + assert websocket.scope['user'].is_authenticated + + +def test_websocket_valid_authentication_and_audience_and_option_ignore_audience(): + secret_key = 'example' + app = create_app() + options = {"verify_aud": False} + app.add_middleware(AuthenticationMiddleware, backend=JWTWebSocketAuthenticationBackend(secret_key=secret_key, + audience="test_aud", + options=options)) + client = TestClient(app) + token = jwt.encode(dict(username="user"), secret_key, algorithm="HS256", + ).decode() + with client.websocket_connect(f"/ws-auth?jwt={token}") as websocket: + data = websocket.receive_text() + assert data == 'Authentication valid' + assert websocket.scope['user'].is_authenticated + + def test_websocket_invalid_token(): secret_key = 'example' app = create_app()