Skip to content

Commit

Permalink
use dependency injection for the public key
Browse files Browse the repository at this point in the history
  • Loading branch information
jennmueng committed Aug 19, 2024
1 parent e0f5302 commit 98c271b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
45 changes: 35 additions & 10 deletions src/seer/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest, Unauthorized

from seer.bootup import module, stub_module
from seer.configuration import AppConfig
from seer.dependency_injection import inject, injected

Expand All @@ -29,12 +30,41 @@ def access_secret(project_id: str, secret_id: str, version_id: str = "latest"):
def get_public_key_from_secret(project_id: str, secret_id: str, version_id: str = "latest"):
pem_data = access_secret(project_id, secret_id, version_id)
public_key = serialization.load_pem_public_key(pem_data.encode(), backend=default_backend())
return public_key
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

return public_key_bytes


class PublicKeyBytes(BaseModel):
bytes: bytes | None


@module.provider
def provide_public_key(config: AppConfig = injected) -> PublicKeyBytes: # type: ignore
return PublicKeyBytes(
bytes=(
get_public_key_from_secret(
config.GOOGLE_CLOUD_PROJECT_ID, config.API_PUBLIC_KEY_SECRET_ID
)
if config.GOOGLE_CLOUD_PROJECT_ID and config.API_PUBLIC_KEY_SECRET_ID
else None
)
)


@stub_module.provider
def provide_public_key_stub() -> PublicKeyBytes: # type: ignore
return PublicKeyBytes(bytes=None)


def json_api(blueprint: Blueprint, url_rule: str) -> Callable[[_F], _F]:
@inject
def decorator(implementation: _F, config: AppConfig = injected) -> _F:
def decorator(
implementation: _F, config: AppConfig = injected, public_key: PublicKeyBytes = injected
) -> _F:
spec = inspect.getfullargspec(implementation)
annotations = get_type_hints(implementation)
try:
Expand All @@ -61,15 +91,10 @@ def wrapper() -> Any:
token = auth_header.split()[1]
try:
try:
if public_key.bytes is None:
raise Unauthorized("Public key is not available")
# Verify the JWT token using PyJWT
public_key = get_public_key_from_secret(
config.GOOGLE_CLOUD_PROJECT_ID, config.API_PUBLIC_KEY_SECRET_ID
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
jwt.decode(token, public_key_bytes, algorithms=["RS256"])
jwt.decode(token, public_key.bytes, algorithms=["RS256"])

# Optionally, you can add additional checks here
# For example, checking the 'exp' claim for token expiration
Expand Down
11 changes: 8 additions & 3 deletions tests/test_json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from seer.configuration import AppConfig
from seer.dependency_injection import resolve
from seer.json_api import json_api
from seer.json_api import PublicKeyBytes, json_api


class DummyRequest(BaseModel):
Expand All @@ -34,6 +34,9 @@ def my_endpoint(request: DummyRequest) -> DummyResponse:

app.register_blueprint(blueprint)

app_config = resolve(AppConfig)
app_config.ENFORCE_API_AUTH = False

response = test_client.post("/v0/some/url", json={"thing": "thing", "b": 12})
assert response.status_code == 200
assert response.get_json() == {"blah": "do it"}
Expand Down Expand Up @@ -84,7 +87,9 @@ def my_endpoint(request: DummyRequest) -> DummyResponse:

app_config = resolve(AppConfig)
app_config.ENFORCE_API_AUTH = True
app_config.API_PUBLIC_KEY = "mock_public_key"

pk = resolve(PublicKeyBytes)
pk.bytes = b"mock_public_key"

with patch("seer.json_api.jwt.decode") as mock_jwt_decode:
# Test valid token
Expand All @@ -94,7 +99,7 @@ def my_endpoint(request: DummyRequest) -> DummyResponse:
)
assert response.status_code == 200
mock_jwt_decode.assert_called_once_with(
"valid_token", "mock_public_key", algorithms=["RS256"]
"valid_token", b"mock_public_key", algorithms=["RS256"]
)

# Test invalid token
Expand Down

0 comments on commit 98c271b

Please sign in to comment.