Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(security): Add API bearer token auth via JWT #1063

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ pydantic-xml==2.9.0
chromadb==0.4.14
google-cloud-storage==2.16.0
google-cloud-aiplatform==1.60.0
google-cloud-secret-manager==2.20.2
anthropic[vertex]==0.31.2
langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e
watchdog
stumpy==1.12.0
pytest_alembic==0.11.1
cryptography==43.0.0
10 changes: 10 additions & 0 deletions src/seer/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ class AppConfig(BaseModel):
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_HOST: str = ""

API_PUBLIC_KEY_SECRET_ID: str = ""
JSON_API_SHARED_SECRETS: ParseList = Field(default_factory=list)
IGNORE_API_AUTH: ParseBool = False # Used for both API Tokens and RPC Secrets

GOOGLE_CLOUD_PROJECT_ID: str = ""

TORCH_NUM_THREADS: ParseInt = 0
NO_SENTRY_INTEGRATION: ParseBool = False
DEV: ParseBool = False
Expand All @@ -74,6 +79,11 @@ def has_sentry_integration(self) -> bool:
return not self.NO_SENTRY_INTEGRATION

def do_validation(self):
if self.IGNORE_API_AUTH:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be `if not self.IGNORE_API_AUTH`?

assert (
self.JSON_API_SHARED_SECRETS or self.API_PUBLIC_KEY_SECRET_ID
), "JSON_API_SHARED_SECRETS or API_PUBLIC_KEY_SECRET_ID required if IGNORE_API_AUTH is true!"

if self.is_production:
assert self.has_sentry_integration, "Sentry integration required for production mode."
assert self.SENTRY_DSN, "SENTRY_DSN required for production!"
Expand Down
86 changes: 80 additions & 6 deletions src/seer/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,70 @@
import logging
from typing import Any, Callable, Type, TypeVar, get_type_hints

import jwt
import sentry_sdk
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from flask import Blueprint, request
from google.cloud import secretmanager
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest, Unauthorized
from werkzeug.exceptions import BadRequest, InternalServerError, Unauthorized

from seer.bootup import module, stub_module
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected

logger = logging.getLogger(__name__)


_F = TypeVar("_F", bound=Callable[..., Any])


def access_secret(project_id: str, secret_id: str, version_id: str = "latest"):
    """Read one version of a secret from Google Secret Manager and return it as text.

    Args:
        project_id: Project segment of the secret's resource name.
        secret_id: Secret segment of the secret's resource name.
        version_id: Version segment of the resource name; defaults to ``"latest"``.

    Returns:
        The secret payload decoded as UTF-8.
    """
    resource_name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
    # A new client is constructed on every call; nothing is cached here.
    secret_version = secretmanager.SecretManagerServiceClient().access_secret_version(
        request={"name": resource_name}
    )
    raw_payload = secret_version.payload.data
    return raw_payload.decode("UTF-8")


def get_public_key_from_secret(project_id: str, secret_id: str, version_id: str = "latest"):
    """Fetch a PEM public key from Secret Manager and return it re-serialized as PEM bytes.

    The secret text is read via ``access_secret``, parsed with
    ``serialization.load_pem_public_key`` and then written back out in
    SubjectPublicKeyInfo / PEM form, so the payload must parse as a PEM public key.

    Args:
        project_id: Forwarded to ``access_secret``.
        secret_id: Forwarded to ``access_secret``.
        version_id: Forwarded to ``access_secret``; defaults to ``"latest"``.

    Returns:
        The public key as PEM-encoded bytes.
    """
    pem_text = access_secret(project_id, secret_id, version_id)
    parsed_key = serialization.load_pem_public_key(pem_text.encode(), backend=default_backend())
    return parsed_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )


class PublicKeyBytes(BaseModel):
    # Injectable holder for the public key that `json_api` passes to `jwt.decode`
    # when verifying Bearer tokens.
    #
    # `bytes` is the key material (PEM bytes when produced by `provide_public_key`),
    # or None when no key is configured; with None, the Bearer-token branch of
    # `json_api` raises Unauthorized("Public key is not available").
    # NOTE(review): the field name shadows the builtin `bytes`; it is kept because
    # providers and tests construct and mutate the model through this name.
    bytes: bytes | None


@module.provider
def provide_public_key(config: AppConfig = injected) -> PublicKeyBytes:
    """Provide the JWT verification key for the real (non-stub) module.

    The key is loaded from Secret Manager only when both ``GOOGLE_CLOUD_PROJECT_ID``
    and ``API_PUBLIC_KEY_SECRET_ID`` are set; otherwise the returned model carries
    ``None`` and no secret lookup is attempted.
    """
    project_id = config.GOOGLE_CLOUD_PROJECT_ID
    secret_id = config.API_PUBLIC_KEY_SECRET_ID

    # Guard clause: without both settings there is nothing to look up.
    if not (project_id and secret_id):
        return PublicKeyBytes(bytes=None)

    return PublicKeyBytes(bytes=get_public_key_from_secret(project_id, secret_id))


@stub_module.provider
def provide_public_key_stub() -> PublicKeyBytes:
    """Stub-module provider: supplies a ``PublicKeyBytes`` holding no key.

    Unlike ``provide_public_key`` it never reads configuration or Secret Manager.
    """
    no_key = PublicKeyBytes(bytes=None)
    return no_key


def json_api(blueprint: Blueprint, url_rule: str) -> Callable[[_F], _F]:
def decorator(implementation: _F) -> _F:
@inject
def decorator(
implementation: _F, config: AppConfig = injected, public_key: PublicKeyBytes = injected
) -> _F:
spec = inspect.getfullargspec(implementation)
annotations = get_type_hints(implementation)
try:
Expand All @@ -43,10 +92,35 @@ def wrapper(config: AppConfig = injected) -> Any:
parts = auth_header.split()
if len(parts) != 2 or not compare_signature(request.url, raw_data, parts[1]):
raise Unauthorized("Rpcsignature did not match for given url and data")
else:
if config.is_production:
logger.warning(f"Found unexpected authorization header: {auth_header}")
raise Unauthorized("Rpcsignature was not included in authorization header!")
elif auth_header.startswith("Bearer "):
token = auth_header.split()[1]
try:
if public_key.bytes is None:
raise Unauthorized("Public key is not available")
# Verify the JWT token using PyJWT
jwt.decode(token, public_key.bytes, algorithms=["RS256"])

# NOTE: jwt.decode already validates the 'exp' claim whenever the token carries one
# (hence the ExpiredSignatureError handler below), but it does not require the claim
# to be present. Additional checks on specific payload claims could be added here.

# If the token is successfully decoded and verified,
# the function will continue execution
except jwt.ExpiredSignatureError:
raise Unauthorized("Token has expired")
except jwt.InvalidSignatureError:
raise Unauthorized("Invalid signature")
except jwt.InvalidTokenError:
raise Unauthorized("Invalid token")
except Exception as e:
sentry_sdk.capture_exception(e)
jennmueng marked this conversation as resolved.
Show resolved Hide resolved
print(e)
raise InternalServerError("Something went wrong with the Bearer token auth")
elif not config.IGNORE_API_AUTH and config.is_production:
logger.warning(f"Found unexpected authorization header: {auth_header}")
raise Unauthorized(
"Neither Rpcsignature nor a Bearer token was included in authorization header!"
)

# Cached from ^^, this won't result in double read.
data = request.get_json()
Expand Down
153 changes: 151 additions & 2 deletions tests/test_json_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from unittest.mock import patch

import jwt
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from flask import Blueprint, Flask
from johen import change_watcher
from pydantic import BaseModel

from seer.configuration import AppConfig
from seer.dependency_injection import Module
from seer.json_api import json_api
from seer.dependency_injection import Module, resolve
from seer.json_api import PublicKeyBytes, json_api


class DummyRequest(BaseModel):
Expand All @@ -29,13 +35,156 @@ def my_endpoint(request: DummyRequest) -> DummyResponse:

app.register_blueprint(blueprint)

app_config = resolve(AppConfig)
app_config.IGNORE_API_AUTH = True

response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12})
assert response.status_code == 200
assert response.get_json() == {"blah": "do it"}

assert my_endpoint(DummyRequest(thing="thing", b=12)) == DummyResponse(blah="do it")


def test_json_api_bearer_token_auth():
    """Exercise the Bearer-token branch of ``json_api`` with ``jwt.decode`` patched out.

    Covers an accepted token, a token rejected by ``jwt.decode``, a request with no
    Authorization header, and a header using an unrecognised scheme.
    """
    flask_app = Flask(__name__)
    bp = Blueprint("blueprint", __name__)
    client = flask_app.test_client()

    @json_api(bp, "/v0/some/url")
    def my_endpoint(request: DummyRequest) -> DummyResponse:
        return DummyResponse(blah="do it")

    flask_app.register_blueprint(bp)

    # Auth must be enforced for the 401 cases below, so turn off both escape hatches.
    config = resolve(AppConfig)
    config.IGNORE_API_AUTH = False
    config.DEV = False

    # The key content is irrelevant because jwt.decode is mocked; it only has to be non-None.
    key_holder = resolve(PublicKeyBytes)
    key_holder.bytes = b"mock_public_key"

    url = "/v0/some/url"
    missing_auth_message = (
        b"Neither Rpcsignature nor a Bearer token was included in authorization header!"
    )

    with patch("seer.json_api.jwt.decode") as mock_jwt_decode:
        # Valid token: decode succeeds (mock returns normally) and the endpoint runs.
        bearer_headers = {"Authorization": "Bearer valid_token"}
        accepted = client.post(url, json={"thing": "thing", "b": 12}, headers=bearer_headers)
        assert accepted.status_code == 200
        mock_jwt_decode.assert_called_once_with(
            "valid_token", b"mock_public_key", algorithms=["RS256"]
        )

        # Invalid token: decode raises, which must surface as a 401.
        mock_jwt_decode.side_effect = jwt.InvalidTokenError
        rejected = client.post(url, json={"thing": "thing", "b": 12}, headers=bearer_headers)
        assert rejected.status_code == 401
        assert b"Invalid token" in rejected.data

        # No Authorization header at all.
        no_header = client.post(url, json={"thing": "thing", "b": 12})
        assert no_header.status_code == 401
        assert missing_auth_message in no_header.data

        # Authorization header present but with an unknown scheme.
        odd_headers = {"Authorization": "InvalidFormat token"}
        odd_scheme = client.post(url, json={"thing": "thing", "b": 12}, headers=odd_headers)
        assert odd_scheme.status_code == 401
        assert missing_auth_message in odd_scheme.data


def test_json_api_auth_not_enforced():
    """A request without any Authorization header is served when auth is not enforced."""
    flask_app = Flask(__name__)
    bp = Blueprint("blueprint", __name__)
    client = flask_app.test_client()

    @json_api(bp, "/v0/some/url")
    def my_endpoint(request: DummyRequest) -> DummyResponse:
        return DummyResponse(blah="do it")

    flask_app.register_blueprint(bp)

    config = resolve(AppConfig)
    config.IGNORE_API_AUTH = False

    # No Authorization header is sent and IGNORE_API_AUTH is False, yet the request is
    # expected to succeed: json_api only rejects unauthenticated requests when
    # `config.is_production` is true.
    # NOTE(review): presumably the resolved test config is not in production mode here
    # (DEV is not overridden in this test) - confirm.
    unauthenticated = client.post("/v0/some/url", json={"thing": "thing", "b": 12})
    assert unauthenticated.status_code == 200
    assert unauthenticated.get_json() == {"blah": "do it"}


def test_json_api_auth_with_real_jwt():
    """End-to-end Bearer auth check using a freshly generated RSA key pair and real JWTs.

    Verifies that an RS256 token signed with the matching private key is accepted,
    an HS256 token signed with an unrelated secret is rejected as invalid, and a
    correctly signed but expired token is rejected as expired.
    """
    import time

    config = resolve(AppConfig)
    config.IGNORE_API_AUTH = False
    config.DEV = False

    # Throwaway RSA key pair: the private half signs tokens, the public half is injected.
    signing_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    verification_pem = signing_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    claims = {"sub": "1234567890", "name": "Test User", "iat": 1516239022}
    good_token = jwt.encode(claims, signing_key, algorithm="RS256")

    # Override the injected PublicKeyBytes so json_api verifies against the test key.
    override_module = Module()
    override_module.constant(PublicKeyBytes, PublicKeyBytes(bytes=verification_pem))
    with override_module:
        flask_app = Flask(__name__)
        bp = Blueprint("blueprint", __name__)
        client = flask_app.test_client()

        @json_api(bp, "/v0/some/url")
        def my_endpoint(request: DummyRequest) -> DummyResponse:
            return DummyResponse(blah="do it")

        flask_app.register_blueprint(bp)

        url = "/v0/some/url"

        # Correctly signed token is accepted.
        accepted = client.post(
            url,
            json={"thing": "thing", "b": 12},
            headers={"Authorization": f"Bearer {good_token}"},
        )
        assert accepted.status_code == 200
        assert accepted.get_json() == {"blah": "do it"}

        # Token signed with a different algorithm and key is rejected.
        forged_token = jwt.encode(claims, "wrong_key", algorithm="HS256")
        rejected = client.post(
            url,
            json={"thing": "thing", "b": 12},
            headers={"Authorization": f"Bearer {forged_token}"},
        )
        assert rejected.status_code == 401
        assert b"Invalid token" in rejected.data

        # Correctly signed token whose 'exp' lies five minutes in the past.
        stale_claims = {"exp": int(time.time()) - 300}
        stale_token = jwt.encode(stale_claims, signing_key, algorithm="RS256")
        expired = client.post(
            url,
            json={"thing": "thing", "b": 12},
            headers={"Authorization": f"Bearer {stale_token}"},
        )
        assert expired.status_code == 401
        assert b"Token has expired" in expired.data


@pytest.mark.skip(reason="Waiting to validate configuration in production")
def test_json_api_signature_strict_mode():
app = Flask(__name__)
blueprint = Blueprint("blueprint", __name__)
Expand Down
Loading