Skip to content

Commit

Permalink
Multi tenant vespa (#2762)
Browse files Browse the repository at this point in the history
* add vespa multi tenancy

* k

* formatting

* Billing (#2667)

* k

* data -> control

* nit

* nit: error handling

* auth + app

* nit: color standardization

* nit

* nit: typing

* k

* k

* feat: functional upgrading

* feat: add block for downgrading to seats < active users

* add auth

* remove accomplished todo + prints

* nit

* tiny nit

* nit: centralize security

* add tenant expulsion/gating + invite user -> increment billing seat no.

* add cloud configs

* k

* k

* nit: update

* k

* k

* k

* k

* nit
  • Loading branch information
pablodanswer authored Oct 12, 2024
1 parent 7eafdae commit 20df20a
Show file tree
Hide file tree
Showing 44 changed files with 1,456 additions and 600 deletions.
1 change: 1 addition & 0 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf

Expand Down
33 changes: 5 additions & 28 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
Expand Down Expand Up @@ -41,10 +42,8 @@
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
Expand Down Expand Up @@ -129,7 +128,10 @@ def verify_email_is_invited(email: str) -> None:
if not email:
raise PermissionError("Email must be specified")

email_info = validate_email(email) # can raise EmailNotValidError
try:
email_info = validate_email(email)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")

for email_whitelist in whitelist:
try:
Expand Down Expand Up @@ -652,28 +654,3 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
def get_default_admin_user_emails_() -> list[str]:
    """Return the e-mail addresses that should be seeded as default admins.

    Returns:
        Always an empty list in this implementation.
    """
    # No default seeding available for Danswer MIT
    return []


async def control_plane_dep(request: Request) -> None:
    """Authenticate a request against the control-plane credentials.

    Two independent credentials must both be present:

    1. An ``X-API-KEY`` header equal to ``EXPECTED_API_KEY``.
    2. An ``Authorization: Bearer <jwt>`` header whose token is an HS256 JWT
       that verifies against ``DATA_PLANE_SECRET`` and carries the scope
       ``"tenant:create"``.

    Args:
        request: The incoming request whose headers are inspected.

    Raises:
        HTTPException: 401 if the API key is missing/incorrect, the
            authorization header is missing/malformed, or the token is
            expired or otherwise invalid; 403 if the token is valid but
            does not carry the required scope.
    """
    # Function-scope import so this block is self-contained.
    import hmac

    api_key = request.headers.get("X-API-KEY")
    # Fail closed when no key is configured: EXPECTED_API_KEY defaults to ""
    # and an empty header value must not be accepted as a match.
    # compare_digest is used so the comparison time does not leak how much of
    # the key was correct; bytes are compared because compare_digest rejects
    # non-ASCII str input.
    api_key_ok = (
        bool(EXPECTED_API_KEY)
        and api_key is not None
        and hmac.compare_digest(
            api_key.encode("utf-8"), EXPECTED_API_KEY.encode("utf-8")
        )
    )
    if not api_key_ok:
        logger.warning("Invalid API key")
        raise HTTPException(status_code=401, detail="Invalid API key")

    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        logger.warning("Invalid authorization header")
        raise HTTPException(status_code=401, detail="Invalid authorization header")

    token = auth_header.split(" ")[1]
    # Keep the try body to the single call that can raise a JWT error.
    try:
        payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        logger.warning("Token has expired")
        raise HTTPException(status_code=401, detail="Token has expired")
    except jwt.InvalidTokenError:
        logger.warning("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")

    if payload.get("scope") != "tenant:create":
        logger.warning("Insufficient permissions")
        raise HTTPException(status_code=403, detail="Insufficient permissions")
10 changes: 9 additions & 1 deletion backend/danswer/background/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.vespa.index import VespaIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -484,7 +485,14 @@ def update_loop(
f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
)
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
index_to_expire = check_index_swap(db_session=db_session)

if index_to_expire and tenant_id and MULTI_TENANT:
VespaIndex.delete_entries_by_tenant_id(
tenant_id=tenant_id,
index_name=index_to_expire.index_name,
)

if not MULTI_TENANT:
search_settings = get_current_search_settings(db_session)
if search_settings.provider_type is None:
Expand Down
24 changes: 20 additions & 4 deletions backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,27 @@
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")


# Cloud configuration

# Multi-tenancy configuration
# Boolean flags are enabled only by the (case-insensitive) string "true".
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"

# Security and authentication
# Each secret is defined exactly once and defaults to the empty string when
# the environment variable is unset.
SECRET_JWT_KEY = os.environ.get(
    "SECRET_JWT_KEY", ""
)  # Used for signing the JWT token for user's tenant context
DATA_PLANE_SECRET = os.environ.get(
    "DATA_PLANE_SECRET", ""
)  # Used for secure communication between the control and data plane
EXPECTED_API_KEY = os.environ.get(
    "EXPECTED_API_KEY", ""
)  # Additional security check for the control plane API

# API configuration
CONTROL_PLANE_API_BASE_URL = os.environ.get(
    "CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
)

# JWT configuration
JWT_ALGORITHM = "HS256"
16 changes: 14 additions & 2 deletions backend/danswer/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session

from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
Expand All @@ -33,10 +35,20 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()


def get_total_users(db_session: Session) -> int:
    """Return the total number of users in the system.

    The total is the number of ``User`` rows reachable through ``db_session``
    plus the number of entries returned by ``get_invited_users()``.

    Args:
        db_session: Session used to count ``User`` rows.

    Returns:
        The sum of the user count and the invited-user count.
    """
    user_count = db_session.query(User).count()
    invited_users = len(get_invited_users())
    return user_count + invited_users


async def get_user_count() -> int:
async with get_async_session_with_tenant() as asession:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/db/search_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:

def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)
Expand Down
11 changes: 9 additions & 2 deletions backend/danswer/db/swap_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy.orm import Session

from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
Expand All @@ -8,16 +9,18 @@
from danswer.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger


logger = setup_logger()


def check_index_swap(db_session: Session) -> None:
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one."""
Expand All @@ -27,7 +30,7 @@ def check_index_swap(db_session: Session) -> None:
search_settings = get_secondary_search_settings(db_session)

if not search_settings:
return
return None

unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
Expand Down Expand Up @@ -63,3 +66,7 @@ def check_index_swap(db_session: Session) -> None:
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)

if MULTI_TENANT:
return now_old_search_settings
return None
11 changes: 11 additions & 0 deletions backend/danswer/document_index/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ def ensure_indices_exist(
"""
raise NotImplementedError

@staticmethod
@abc.abstractmethod
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
) -> None:
"""
Register multitenant indices with the document index.
"""
raise NotImplementedError


class Indexable(abc.ABC):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
schema DANSWER_CHUNK_NAME {
document DANSWER_CHUNK_NAME {
TENANT_ID_REPLACEMENT
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute
Expand Down
Loading

0 comments on commit 20df20a

Please sign in to comment.