Skip to content

Commit

Permalink
Basic multi tenant api key (#3004)
Browse files (browse the repository at this point in the history)
* basic multi tenant api key

* organization

* nit

* clean
  • Loading branch information
pablodanswer authored Nov 1, 2024
1 parent 6d543f3 commit 753293c
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 34 deletions.
33 changes: 31 additions & 2 deletions backend/ee/danswer/auth/api_key.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import secrets
import uuid
from urllib.parse import quote
from urllib.parse import unquote

from fastapi import Request
from passlib.hash import sha256_crypt
Expand Down Expand Up @@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel):
user_id: uuid.UUID


def generate_api_key() -> str:
    """Return a fresh API key: the key prefix followed by a random token.

    The random part comes from ``secrets.token_urlsafe`` called with
    ``_API_KEY_LEN``.
    """
    random_token = secrets.token_urlsafe(_API_KEY_LEN)
    return _API_KEY_PREFIX + random_token
def generate_api_key(tenant_id: str | None = None) -> str:
    """Generate a new random API key, optionally embedding a tenant ID.

    Args:
        tenant_id: Tenant to embed in the key. When falsy (``None`` or ``""``),
            an old-style key without a tenant segment is produced for
            backwards compatibility.

    Returns:
        ``<prefix><token>`` when no tenant is given, otherwise
        ``<prefix><encoded tenant>.<token>``.
    """
    random_token = secrets.token_urlsafe(_API_KEY_LEN)

    # For backwards compatibility, if no tenant_id, generate old style key
    if not tenant_id:
        return _API_KEY_PREFIX + random_token

    # "." separates the tenant segment from the token, and
    # extract_tenant_from_api_key_header splits on the first "." it finds.
    # quote() treats "." as an unreserved character and never escapes it, so a
    # tenant ID containing "." would be truncated on parsing. Escape it
    # explicitly; unquote() turns "%2E" back into ".".
    encoded_tenant = quote(tenant_id).replace(".", "%2E")
    return f"{_API_KEY_PREFIX}{encoded_tenant}.{random_token}"


def extract_tenant_from_api_key_header(request: Request) -> str | None:
    """Pull the tenant ID out of the API key carried in the request headers.

    The alternative API key header is consulted first, then the standard one.

    Returns ``None`` when no header value is present, when the value does not
    start with the bearer prefix, when the key does not start with the API key
    prefix, when the key has no ``.`` separator (e.g. a key generated without
    a tenant), or when the tenant segment is empty. Otherwise the URL-decoded
    tenant segment is returned.
    """
    header_value = request.headers.get(_API_KEY_HEADER_ALTERNATIVE_NAME)
    if not header_value:
        header_value = request.headers.get(_API_KEY_HEADER_NAME)

    if not (header_value and header_value.startswith(_BEARER_PREFIX)):
        return None

    key = header_value[len(_BEARER_PREFIX) :].strip()
    if not key.startswith(_API_KEY_PREFIX):
        return None

    # Everything before the first "." is the (URL-encoded) tenant segment.
    encoded_tenant, separator, _token = key[len(_API_KEY_PREFIX) :].partition(".")
    if not separator or not encoded_tenant:
        return None

    return unquote(encoded_tenant)


def hash_api_key(api_key: str) -> str:
Expand Down
8 changes: 7 additions & 1 deletion backend/ee/danswer/db/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT


def get_api_key_email_pattern() -> str:
Expand Down Expand Up @@ -64,7 +66,11 @@ def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
std_password_helper = PasswordHelper()
api_key = generate_api_key()

# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()

display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
Expand Down
77 changes: 46 additions & 31 deletions backend/ee/danswer/server/middleware/tenant_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
Expand All @@ -21,40 +22,54 @@ async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("fastapiusersauth")

if token:
try:
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
except jwt.InvalidTokenError:
tenant_id = POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
)
raise HTTPException(
status_code=500, detail="Internal server error"
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = POSTGRES_DEFAULT_SCHEMA

if MULTI_TENANT:
tenant_id = _get_tenant_id_from_request(request, logger)

CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
response = await call_next(request)
return response
return await call_next(request)

except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise


def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
    """Resolve the tenant ID for a request from its API key or auth cookie.

    Resolution order:
      1. Tenant embedded in the API key header, if one can be extracted.
      2. ``tenant_id`` claim of the ``fastapiusersauth`` JWT cookie.
      3. ``POSTGRES_DEFAULT_SCHEMA`` when there is no cookie, the token is
         invalid, or the claim is absent or ``None``.

    Args:
        request: Incoming request whose headers and cookies are inspected.
        logger: Logger used to report unexpected failures.

    Returns:
        The resolved tenant ID.

    Raises:
        HTTPException: 400 if the resolved tenant ID fails
            ``is_valid_schema_name``; 500 on any unexpected error while
            processing the cookie.
    """
    # First check for API key
    tenant_id = extract_tenant_from_api_key_header(request)
    if tenant_id is not None:
        # The header is client-controlled, so apply the same format check the
        # cookie path uses before the value is adopted as the tenant.
        if not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID format")
        return tenant_id

    # Check for cookie-based auth
    token = request.cookies.get("fastapiusersauth")
    if not token:
        return POSTGRES_DEFAULT_SCHEMA

    try:
        payload = jwt.decode(
            token,
            USER_AUTH_SECRET,
            audience=["fastapi-users:auth"],
            algorithms=["HS256"],
        )
        tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)

        # Since payload.get() can return None, ensure we have a string
        tenant_id = (
            str(tenant_id_from_payload)
            if tenant_id_from_payload is not None
            else POSTGRES_DEFAULT_SCHEMA
        )

        if not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID format")

        return tenant_id

    except jwt.InvalidTokenError:
        return POSTGRES_DEFAULT_SCHEMA

    except HTTPException:
        # Let the deliberate 400 above through unchanged; without this clause
        # the generic handler below would rewrite it as a 500.
        raise

    except Exception as e:
        logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error") from e

0 comments on commit 753293c

Please sign in to comment.