Skip to content

Commit

Permalink
feat(security): Add API bearer token auth via JWT (#1063)
Browse files Browse the repository at this point in the history
Mainly for the staging deployment in `ml-ai`

Introduces an `IGNORE_API_AUTH` env variable; unless it is set, requests must
pass either the RPC secret signing or the bearer token check (enforced in production).
  • Loading branch information
jennmueng committed Sep 4, 2024
1 parent 7986fce commit 9ef0b6d
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 8 deletions.
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ services:
- GOOGLE_APPLICATION_CREDENTIALS=/root/.config/gcloud/application_default_credentials.json
- GOOGLE_CLOUD_PROJECT_ID=ml-ai-420606
- USE_EU_REGION=0
- IGNORE_API_AUTH=1
ports:
- "9091:9091" # Local dev sentry app looks for port 9091 for the seer service.
volumes:
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ pydantic-xml==2.9.0
chromadb==0.4.14
google-cloud-storage==2.16.0
google-cloud-aiplatform==1.60.0
google-cloud-secret-manager==2.20.2
anthropic[vertex]==0.31.2
langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e
watchdog
stumpy==1.13.0
pytest_alembic==0.11.1
cryptography==43.0.0
10 changes: 10 additions & 0 deletions src/seer/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ class AppConfig(BaseModel):
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_HOST: str = ""

API_PUBLIC_KEY_SECRET_ID: str = ""
JSON_API_SHARED_SECRETS: ParseList = Field(default_factory=list)
IGNORE_API_AUTH: ParseBool = False # Used for both API Tokens and RPC Secrets

GOOGLE_CLOUD_PROJECT_ID: str = ""

TORCH_NUM_THREADS: ParseInt = 0
NO_SENTRY_INTEGRATION: ParseBool = False
DEV: ParseBool = False
Expand All @@ -74,6 +79,11 @@ def has_sentry_integration(self) -> bool:
return not self.NO_SENTRY_INTEGRATION

def do_validation(self):
if not self.IGNORE_API_AUTH:
assert (
self.JSON_API_SHARED_SECRETS or self.API_PUBLIC_KEY_SECRET_ID
), "JSON_API_SHARED_SECRETS or API_PUBLIC_KEY_SECRET_ID required if IGNORE_API_AUTH is false!"

if self.is_production:
assert self.has_sentry_integration, "Sentry integration required for production mode."
assert self.SENTRY_DSN, "SENTRY_DSN required for production!"
Expand Down
86 changes: 80 additions & 6 deletions src/seer/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,70 @@
import logging
from typing import Any, Callable, Type, TypeVar, get_type_hints

import jwt
import sentry_sdk
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from flask import Blueprint, request
from google.cloud import secretmanager
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest, Unauthorized
from werkzeug.exceptions import BadRequest, InternalServerError, Unauthorized

from seer.bootup import module, stub_module
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected

logger = logging.getLogger(__name__)


_F = TypeVar("_F", bound=Callable[..., Any])


def access_secret(project_id: str, secret_id: str, version_id: str = "latest"):
    """Read one version of a Google Secret Manager secret and return it as text.

    Args:
        project_id: Project segment of the secret's resource path.
        secret_id: Secret segment of the secret's resource path.
        version_id: Version segment of the resource path; defaults to ``"latest"``.

    Returns:
        The secret payload decoded as UTF-8.
    """
    # Fully qualified resource path of the requested secret version.
    resource_name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
    secret_version = secretmanager.SecretManagerServiceClient().access_secret_version(
        request={"name": resource_name}
    )
    secret_bytes = secret_version.payload.data
    return secret_bytes.decode("UTF-8")


def get_public_key_from_secret(project_id: str, secret_id: str, version_id: str = "latest"):
    """Fetch a PEM public key stored as a secret and return it re-serialized as PEM bytes.

    Args:
        project_id: Passed through to ``access_secret``.
        secret_id: Passed through to ``access_secret``.
        version_id: Passed through to ``access_secret``; defaults to ``"latest"``.

    Returns:
        The key serialized as PEM in SubjectPublicKeyInfo format.
    """
    pem_text = access_secret(project_id, secret_id, version_id)
    # Parse the stored PEM into a key object, then serialize that object back out.
    loaded_key = serialization.load_pem_public_key(pem_text.encode(), backend=default_backend())
    return loaded_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


class PublicKeyBytes(BaseModel):
    """Injectable holder for the public key that bearer-token verification reads."""

    # PEM-encoded public key bytes, or None when no key has been provided.
    # NOTE: the field name shadows the builtin ``bytes`` inside this class body.
    bytes: bytes | None


@module.provider
def provide_public_key(config: AppConfig = injected) -> PublicKeyBytes:
    """Provide the API public key, loading it from the configured secret when possible.

    The key is only fetched when both ``GOOGLE_CLOUD_PROJECT_ID`` and
    ``API_PUBLIC_KEY_SECRET_ID`` are set; otherwise the holder carries ``None``.
    """
    project_id = config.GOOGLE_CLOUD_PROJECT_ID
    secret_id = config.API_PUBLIC_KEY_SECRET_ID

    # Without both identifiers there is no secret to look up.
    if not (project_id and secret_id):
        return PublicKeyBytes(bytes=None)

    return PublicKeyBytes(bytes=get_public_key_from_secret(project_id, secret_id))


@stub_module.provider
def provide_public_key_stub() -> PublicKeyBytes:
    """Stub provider: supplies a ``PublicKeyBytes`` that carries no key."""
    no_key = PublicKeyBytes(bytes=None)
    return no_key


def json_api(blueprint: Blueprint, url_rule: str) -> Callable[[_F], _F]:
def decorator(implementation: _F) -> _F:
@inject
def decorator(
implementation: _F, config: AppConfig = injected, public_key: PublicKeyBytes = injected
) -> _F:
spec = inspect.getfullargspec(implementation)
annotations = get_type_hints(implementation)
try:
Expand All @@ -43,10 +92,35 @@ def wrapper(config: AppConfig = injected) -> Any:
parts = auth_header.split()
if len(parts) != 2 or not compare_signature(request.url, raw_data, parts[1]):
raise Unauthorized("Rpcsignature did not match for given url and data")
else:
if config.is_production:
logger.warning(f"Found unexpected authorization header: {auth_header}")
raise Unauthorized("Rpcsignature was not included in authorization header!")
elif auth_header.startswith("Bearer "):
token = auth_header.split()[1]
try:
if public_key.bytes is None:
raise Unauthorized("Public key is not available")
# Verify the JWT token using PyJWT
jwt.decode(token, public_key.bytes, algorithms=["RS256"])

# Optionally, you can add additional checks here
# For example, checking the 'exp' claim for token expiration
# or verifying specific claims in the token payload

# If the token is successfully decoded and verified,
# the function will continue execution
except jwt.ExpiredSignatureError:
raise Unauthorized("Token has expired")
except jwt.InvalidSignatureError:
raise Unauthorized("Invalid signature")
except jwt.InvalidTokenError:
raise Unauthorized("Invalid token")
except Exception as e:
sentry_sdk.capture_exception(e)
print(e)
raise InternalServerError("Something went wrong with the Bearer token auth")
elif not config.IGNORE_API_AUTH and config.is_production:
logger.warning(f"Found unexpected authorization header: {auth_header}")
raise Unauthorized(
"Neither Rpcsignature nor a Bearer token was included in authorization header!"
)

# Cached from ^^, this won't result in double read.
data = request.get_json()
Expand Down
153 changes: 151 additions & 2 deletions tests/test_json_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from unittest.mock import patch

import jwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from flask import Blueprint, Flask
from johen import change_watcher
from pydantic import BaseModel

from seer.configuration import AppConfig
from seer.dependency_injection import Module
from seer.json_api import json_api
from seer.dependency_injection import Module, resolve
from seer.json_api import PublicKeyBytes, json_api


class DummyRequest(BaseModel):
Expand All @@ -29,13 +35,156 @@ def my_endpoint(request: DummyRequest) -> DummyResponse:

app.register_blueprint(blueprint)

app_config = resolve(AppConfig)
app_config.IGNORE_API_AUTH = True

response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12})
assert response.status_code == 200
assert response.get_json() == {"blah": "do it"}

assert my_endpoint(DummyRequest(thing="thing", b=12)) == DummyResponse(blah="do it")


def test_json_api_bearer_token_auth():
    """Bearer auth with a mocked ``jwt.decode``: valid token passes, everything else is 401."""
    flask_app = Flask(__name__)
    bp = Blueprint("blueprint", __name__)
    client = flask_app.test_client()

    @json_api(bp, "/v0/some/url")
    def my_endpoint(request: DummyRequest) -> DummyResponse:
        return DummyResponse(blah="do it")

    flask_app.register_blueprint(bp)

    config = resolve(AppConfig)
    config.IGNORE_API_AUTH = False
    config.DEV = False

    key_holder = resolve(PublicKeyBytes)
    key_holder.bytes = b"mock_public_key"

    # Values shared by every request below.
    url = "/v0/some/url"
    body = {"thing": "thing", "b": 12}
    bearer_headers = {"Authorization": "Bearer valid_token"}
    missing_auth_message = (
        b"Neither Rpcsignature nor a Bearer token was included in authorization header!"
    )

    with patch("seer.json_api.jwt.decode") as decode_mock:
        # A token the (mocked) decoder accepts goes through.
        accepted = client.post(url, json=body, headers=bearer_headers)
        assert accepted.status_code == 200
        decode_mock.assert_called_once_with(
            "valid_token", b"mock_public_key", algorithms=["RS256"]
        )

        # The decoder rejecting the token surfaces as a 401.
        decode_mock.side_effect = jwt.InvalidTokenError
        rejected = client.post(url, json=body, headers=bearer_headers)
        assert rejected.status_code == 401
        assert b"Invalid token" in rejected.data

        # No Authorization header at all.
        no_header = client.post(url, json=body)
        assert no_header.status_code == 401
        assert missing_auth_message in no_header.data

        # Authorization header with an unrecognised scheme.
        unknown_scheme = client.post(
            url, json=body, headers={"Authorization": "InvalidFormat token"}
        )
        assert unknown_scheme.status_code == 401
        assert missing_auth_message in unknown_scheme.data


def test_json_api_auth_not_enforced():
    """An unauthenticated request succeeds when ``IGNORE_API_AUTH`` is enabled."""
    app = Flask(__name__)
    blueprint = Blueprint("blueprint", __name__)
    test_client = app.test_client()

    @json_api(blueprint, "/v0/some/url")
    def my_endpoint(request: DummyRequest) -> DummyResponse:
        return DummyResponse(blah="do it")

    app.register_blueprint(blueprint)

    app_config = resolve(AppConfig)
    # "Not enforced" means the ignore flag is ON. The previous value (False) left
    # enforcement enabled and only passed because the config was not in production mode.
    app_config.IGNORE_API_AUTH = True
    # Same non-DEV setting under which test_json_api_bearer_token_auth gets a 401 for a
    # missing header, so the 200 below is attributable to IGNORE_API_AUTH itself.
    app_config.DEV = False

    # Test that request is allowed without any auth when IGNORE_API_AUTH is True
    response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12})
    assert response.status_code == 200
    assert response.get_json() == {"blah": "do it"}


def test_json_api_auth_with_real_jwt():
    """Bearer auth against real PyJWT verification using a freshly generated RSA key pair."""
    import time

    config = resolve(AppConfig)
    config.IGNORE_API_AUTH = False
    config.DEV = False

    # Throwaway RSA key pair: the private half signs tokens, the public half is injected.
    signing_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    verifying_key = signing_key.public_key()
    verifying_pem = verifying_key.public_bytes(
        encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
    )

    claims = {"sub": "1234567890", "name": "Test User", "iat": 1516239022}
    good_token = jwt.encode(claims, signing_key, algorithm="RS256")

    url = "/v0/some/url"
    body = {"thing": "thing", "b": 12}

    key_module = Module()
    key_module.constant(PublicKeyBytes, PublicKeyBytes(bytes=verifying_pem))
    with key_module:
        flask_app = Flask(__name__)
        bp = Blueprint("blueprint", __name__)
        client = flask_app.test_client()

        @json_api(bp, "/v0/some/url")
        def my_endpoint(request: DummyRequest) -> DummyResponse:
            return DummyResponse(blah="do it")

        flask_app.register_blueprint(bp)

        # Token signed with the matching private key is accepted.
        accepted = client.post(
            url, json=body, headers={"Authorization": f"Bearer {good_token}"}
        )
        assert accepted.status_code == 200
        assert accepted.get_json() == {"blah": "do it"}

        # Token signed with an unrelated HS256 secret is rejected.
        forged_token = jwt.encode(claims, "wrong_key", algorithm="HS256")
        forged = client.post(
            url, json=body, headers={"Authorization": f"Bearer {forged_token}"}
        )
        assert forged.status_code == 401
        assert b"Invalid token" in forged.data

        # Correctly signed token whose ``exp`` is five minutes in the past is rejected.
        stale_claims = {"exp": int(time.time()) - 300}
        stale_token = jwt.encode(stale_claims, signing_key, algorithm="RS256")
        stale = client.post(
            url, json=body, headers={"Authorization": f"Bearer {stale_token}"}
        )
        assert stale.status_code == 401
        assert b"Token has expired" in stale.data


@pytest.mark.skip(reason="Waiting to validate configuration in production")
def test_json_api_signature_strict_mode():
app = Flask(__name__)
blueprint = Blueprint("blueprint", __name__)
Expand Down

0 comments on commit 9ef0b6d

Please sign in to comment.