From 9ef0b6d77ac33447c2f24769dc4e9e4a348c1aa0 Mon Sep 17 00:00:00 2001 From: Jenn Mueng <30991498+jennmueng@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:50:01 -0700 Subject: [PATCH] feat(security): Add API bearer token auth via JWT (#1063) Mainly for the staging deployment in `ml-ai` Introduces an `ENFORCE_API_AUTH` env variable that will enforce either the rpc secret signing or the bearer token. --- docker-compose.yml | 1 + requirements.txt | 2 + src/seer/configuration.py | 10 +++ src/seer/json_api.py | 86 +++++++++++++++++++-- tests/test_json_api.py | 153 +++++++++++++++++++++++++++++++++++++- 5 files changed, 244 insertions(+), 8 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 37315eb5..e00f9155 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -60,6 +60,7 @@ services: - GOOGLE_APPLICATION_CREDENTIALS=/root/.config/gcloud/application_default_credentials.json - GOOGLE_CLOUD_PROJECT_ID=ml-ai-420606 - USE_EU_REGION=0 + - IGNORE_API_AUTH=1 ports: - "9091:9091" # Local dev sentry app looks for port 9091 for the seer service. volumes: diff --git a/requirements.txt b/requirements.txt index b9f5d867..5cdff4d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -100,8 +100,10 @@ pydantic-xml==2.9.0 chromadb==0.4.14 google-cloud-storage==2.16.0 google-cloud-aiplatform==1.60.0 +google-cloud-secret-manager==2.20.2 anthropic[vertex]==0.31.2 langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e watchdog stumpy==1.13.0 pytest_alembic==0.11.1 +cryptography==43.0.0 diff --git a/src/seer/configuration.py b/src/seer/configuration.py index 27842da9..97d79439 100644 --- a/src/seer/configuration.py +++ b/src/seer/configuration.py @@ -54,7 +54,12 @@ class AppConfig(BaseModel): LANGFUSE_SECRET_KEY: str = "" LANGFUSE_HOST: str = "" + API_PUBLIC_KEY_SECRET_ID: str = "" JSON_API_SHARED_SECRETS: ParseList = Field(default_factory=list) + IGNORE_API_AUTH: ParseBool = False # Used for both API Tokens and RPC Secrets + + GOOGLE_CLOUD_PROJECT_ID: str = "" + TORCH_NUM_THREADS: ParseInt = 0 NO_SENTRY_INTEGRATION: ParseBool = False DEV: ParseBool = False @@ -74,6 +79,11 @@ def has_sentry_integration(self) -> bool: return not self.NO_SENTRY_INTEGRATION def do_validation(self): + if not self.IGNORE_API_AUTH: + assert ( + self.JSON_API_SHARED_SECRETS or self.API_PUBLIC_KEY_SECRET_ID + ), "JSON_API_SHARED_SECRETS or API_PUBLIC_KEY_SECRET_ID required if IGNORE_API_AUTH is false!" + if self.is_production: assert self.has_sentry_integration, "Sentry integration required for production mode." assert self.SENTRY_DSN, "SENTRY_DSN required for production!" diff --git a/src/seer/json_api.py b/src/seer/json_api.py index 8a8634f6..7b06d475 100644 --- a/src/seer/json_api.py +++ b/src/seer/json_api.py @@ -5,21 +5,70 @@ import logging from typing import Any, Callable, Type, TypeVar, get_type_hints +import jwt import sentry_sdk +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization from flask import Blueprint, request +from google.cloud import secretmanager from pydantic import BaseModel, ValidationError -from werkzeug.exceptions import BadRequest, Unauthorized +from werkzeug.exceptions import BadRequest, InternalServerError, Unauthorized +from seer.bootup import module, stub_module from seer.configuration import AppConfig from seer.dependency_injection import inject, injected logger = logging.getLogger(__name__) + _F = TypeVar("_F", bound=Callable[..., Any]) +def access_secret(project_id: str, secret_id: str, version_id: str = "latest"): + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}" + response = client.access_secret_version(request={"name": name}) + return response.payload.data.decode("UTF-8") + + +def get_public_key_from_secret(project_id: str, secret_id: str, version_id: str = "latest"): + pem_data = access_secret(project_id, secret_id, version_id) + public_key = serialization.load_pem_public_key(pem_data.encode(), backend=default_backend()) + public_key_bytes = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return public_key_bytes + + +class PublicKeyBytes(BaseModel): + bytes: bytes | None + + +@module.provider +def provide_public_key(config: AppConfig = injected) -> PublicKeyBytes: + return PublicKeyBytes( + bytes=( + get_public_key_from_secret( + config.GOOGLE_CLOUD_PROJECT_ID, config.API_PUBLIC_KEY_SECRET_ID + ) + if config.GOOGLE_CLOUD_PROJECT_ID and config.API_PUBLIC_KEY_SECRET_ID + else None + ) + ) + + +@stub_module.provider +def provide_public_key_stub() -> PublicKeyBytes: + return PublicKeyBytes(bytes=None) + + def json_api(blueprint: Blueprint, url_rule: str) -> Callable[[_F], _F]: - def decorator(implementation: _F) -> _F: + @inject + def decorator( + implementation: _F, config: AppConfig = injected, public_key: PublicKeyBytes = injected + ) -> _F: spec = inspect.getfullargspec(implementation) annotations = get_type_hints(implementation) try: @@ -43,10 +92,35 @@ def wrapper(config: AppConfig = injected) -> Any: parts = auth_header.split() if len(parts) != 2 or not compare_signature(request.url, raw_data, parts[1]): raise Unauthorized("Rpcsignature did not match for given url and data") - else: - if config.is_production: - logger.warning(f"Found unexpected authorization header: {auth_header}") - raise Unauthorized("Rpcsignature was not included in authorization header!") + elif auth_header.startswith("Bearer "): + token = auth_header.split()[1] + try: + if public_key.bytes is None: + raise Unauthorized("Public key is not available") + # Verify the JWT token using PyJWT + jwt.decode(token, public_key.bytes, algorithms=["RS256"]) + + # Optionally, you can add additional checks here + # For example, checking the 'exp' claim for token expiration + # or verifying specific claims in the token payload + + # If the token is successfully decoded and verified, + # the function will continue execution + except jwt.ExpiredSignatureError: + raise Unauthorized("Token has expired") + except jwt.InvalidSignatureError: + raise Unauthorized("Invalid signature") + except jwt.InvalidTokenError: + raise Unauthorized("Invalid token") + except Exception as e: + sentry_sdk.capture_exception(e) + print(e) + raise InternalServerError("Something went wrong with the Bearer token auth") + elif not config.IGNORE_API_AUTH and config.is_production: + logger.warning(f"Found unexpected authorization header: {auth_header}") + raise Unauthorized( + "Neither Rpcsignature nor a Bearer token was included in authorization header!" + ) # Cached from ^^, this won't result in double read. data = request.get_json() diff --git a/tests/test_json_api.py b/tests/test_json_api.py index e5af1d2b..359f4d1c 100644 --- a/tests/test_json_api.py +++ b/tests/test_json_api.py @@ -1,10 +1,16 @@ +from unittest.mock import patch + +import jwt +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from flask import Blueprint, Flask from johen import change_watcher from pydantic import BaseModel from seer.configuration import AppConfig -from seer.dependency_injection import Module -from seer.json_api import json_api +from seer.dependency_injection import Module, resolve +from seer.json_api import PublicKeyBytes, json_api class DummyRequest(BaseModel): @@ -29,6 +35,9 @@ def my_endpoint(request: DummyRequest) -> DummyResponse: app.register_blueprint(blueprint) + app_config = resolve(AppConfig) + app_config.IGNORE_API_AUTH = True + response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12}) assert response.status_code == 200 assert response.get_json() == {"blah": "do it"} @@ -36,6 +45,146 @@ def my_endpoint(request: DummyRequest) -> DummyResponse: assert my_endpoint(DummyRequest(thing="thing", b=12)) == DummyResponse(blah="do it") +def test_json_api_bearer_token_auth(): + app = Flask(__name__) + blueprint = Blueprint("blueprint", __name__) + test_client = app.test_client() + + @json_api(blueprint, "/v0/some/url") + def my_endpoint(request: DummyRequest) -> DummyResponse: + return DummyResponse(blah="do it") + + app.register_blueprint(blueprint) + + app_config = resolve(AppConfig) + app_config.IGNORE_API_AUTH = False + app_config.DEV = False + + pk = resolve(PublicKeyBytes) + pk.bytes = b"mock_public_key" + + with patch("seer.json_api.jwt.decode") as mock_jwt_decode: + # Test valid token + headers = {"Authorization": "Bearer valid_token"} + response = test_client.post( + "/v0/some/url", json={"thing": "thing", "b": 12}, headers=headers + ) + assert response.status_code == 200 + mock_jwt_decode.assert_called_once_with( + "valid_token", b"mock_public_key", algorithms=["RS256"] + ) + + # Test invalid token + mock_jwt_decode.side_effect = jwt.InvalidTokenError + response = test_client.post( + "/v0/some/url", json={"thing": "thing", "b": 12}, headers=headers + ) + assert response.status_code == 401 + assert b"Invalid token" in response.data + + # Test missing Authorization header + response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12}) + assert response.status_code == 401 + assert ( + b"Neither Rpcsignature nor a Bearer token was included in authorization header!" + in response.data + ) + + # Test incorrect Authorization header format + headers = {"Authorization": "InvalidFormat token"} + response = test_client.post( + "/v0/some/url", json={"thing": "thing", "b": 12}, headers=headers + ) + assert response.status_code == 401 + assert ( + b"Neither Rpcsignature nor a Bearer token was included in authorization header!" + in response.data + ) + + +def test_json_api_auth_not_enforced(): + app = Flask(__name__) + blueprint = Blueprint("blueprint", __name__) + test_client = app.test_client() + + @json_api(blueprint, "/v0/some/url") + def my_endpoint(request: DummyRequest) -> DummyResponse: + return DummyResponse(blah="do it") + + app.register_blueprint(blueprint) + + app_config = resolve(AppConfig) + app_config.IGNORE_API_AUTH = False + + # Test that request is allowed without any auth when ENFORCE_API_AUTH is False + response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12}) + assert response.status_code == 200 + assert response.get_json() == {"blah": "do it"} + + +def test_json_api_auth_with_real_jwt(): + + app_config = resolve(AppConfig) + app_config.IGNORE_API_AUTH = False + app_config.DEV = False + + # Generate a test RSA key pair + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + + # Convert public key to PEM format + public_key_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + # Create a test JWT token + payload = {"sub": "1234567890", "name": "Test User", "iat": 1516239022} + token = jwt.encode(payload, private_key, algorithm="RS256") + + module = Module() + module.constant(PublicKeyBytes, PublicKeyBytes(bytes=public_key_pem)) + with module: + app = Flask(__name__) + blueprint = Blueprint("blueprint", __name__) + test_client = app.test_client() + + @json_api(blueprint, "/v0/some/url") + def my_endpoint(request: DummyRequest) -> DummyResponse: + return DummyResponse(blah="do it") + + app.register_blueprint(blueprint) + + # Test valid token + headers = {"Authorization": f"Bearer {token}"} + response = test_client.post( + "/v0/some/url", json={"thing": "thing", "b": 12}, headers=headers + ) + assert response.status_code == 200 + assert response.get_json() == {"blah": "do it"} + + # Test invalid token + invalid_token = jwt.encode(payload, "wrong_key", algorithm="HS256") + headers = {"Authorization": f"Bearer {invalid_token}"} + response = test_client.post( + "/v0/some/url", json={"thing": "thing", "b": 12}, headers=headers + ) + assert response.status_code == 401 + assert b"Invalid token" in response.data + + # Test expired token + import time + + expired_payload = {"exp": int(time.time()) - 300} # Token expired 5 minutes ago + expired_token = jwt.encode(expired_payload, private_key, algorithm="RS256") + headers = {"Authorization": f"Bearer {expired_token}"} + response = test_client.post( + "/v0/some/url", json={"thing": "thing", "b": 12}, headers=headers + ) + assert response.status_code == 401 + assert b"Token has expired" in response.data + + +@pytest.mark.skip(reason="Waiting to validate configuration in production") def test_json_api_signature_strict_mode(): app = Flask(__name__) blueprint = Blueprint("blueprint", __name__)