Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
pablodanswer committed Oct 31, 2024
1 parent 3fb6e9b commit 67e347a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
22 changes: 10 additions & 12 deletions backend/ee/danswer/auth/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,16 @@ def generate_api_key(tenant_id: str | None = None) -> str:
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"


def extract_tenant_from_api_key(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled."""
def extract_tenant_from_api_key_header(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)

api_key_header = request.headers.get("Authorization")
tenant_id = None
if not api_key_header or not api_key_header.startswith("Bearer "):
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
return None

api_key = api_key_header[7:] # Remove "Bearer " prefix
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()

if not api_key.startswith(_API_KEY_PREFIX):
return None
Expand All @@ -58,15 +59,12 @@ def extract_tenant_from_api_key(request: Request) -> str | None:
if len(parts) != 2:
return None

tenant_id, _ = parts
if not tenant_id:
return None

return unquote(tenant_id)
tenant_id = parts[0]
return unquote(tenant_id) if tenant_id else None


def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randoml py generated
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)

Expand Down
4 changes: 2 additions & 2 deletions backend/ee/danswer/server/middleware/tenant_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
Expand All @@ -37,7 +37,7 @@ async def set_tenant_id(

def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
tenant_id = extract_tenant_from_api_key(request)
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
return tenant_id

Expand Down

0 comments on commit 67e347a

Please sign in to comment.