From 7ef0bc79dcf6ff77bbefdd57cb5e5314e03c0a9b Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 13:44:01 -0800 Subject: [PATCH 01/12] add provisioning on data plane --- backend/danswer/auth/users.py | 1 + backend/danswer/main.py | 21 +- backend/ee/danswer/auth/tenant.py | 0 .../ee/danswer/server/tenants/provisioning.py | 191 +++++++++++++----- .../ee/danswer/server/tenants/registration.py | 91 +++++++++ 5 files changed, 251 insertions(+), 53 deletions(-) create mode 100644 backend/ee/danswer/auth/tenant.py create mode 100644 backend/ee/danswer/server/tenants/registration.py diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 0cb4ae2326c..9cfaa2dec6a 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -245,6 +245,7 @@ async def create( else POSTGRES_DEFAULT_SCHEMA ) except exceptions.UserNotExists: + # We should provision a tenant raise HTTPException(status_code=401, detail="User not found") if not tenant_id: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index ae18ab3ccf2..434112c5d5e 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -277,12 +277,21 @@ def get_application() -> FastAPI: prefix="/auth", tags=["auth"], ) - include_router_with_global_prefix_prepended( - application, - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], - ) + if not MULTI_TENANT: + include_router_with_global_prefix_prepended( + application, + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + ) + else: + include_router_with_global_prefix_prepended( + application, + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + ) + # include_router_with_global_prefix_prepended( application, fastapi_users.get_reset_password_router(), diff --git a/backend/ee/danswer/auth/tenant.py b/backend/ee/danswer/auth/tenant.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 9106821b5a5..eb267e4b6bb 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -1,12 +1,18 @@ +import logging import os +import uuid from types import SimpleNamespace +import aiohttp # Async HTTP client +from fastapi import HTTPException from sqlalchemy import text from sqlalchemy.orm import Session from sqlalchemy.schema import CreateSchema from alembic import command from alembic.config import Config +from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL +from danswer.configs.app_configs import EXPECTED_API_KEY from danswer.db.engine import build_connection_string from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine @@ -15,14 +21,150 @@ from danswer.db.models import UserTenantMapping from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from danswer.server.manage.llm.models import LLMProviderUpsertRequest -from danswer.utils.logger import setup_logger +from danswer.setup import setup_danswer from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY +from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider -logger = setup_logger() +logger = logging.getLogger(__name__) + + +def drop_schema(tenant_id: str) -> None: + with get_sqlalchemy_engine().connect() as connection: + connection.execute(text(f"DROP SCHEMA IF EXISTS {tenant_id} CASCADE")) + + +class TenantProvisioningService: + async def provision_tenant(self, email: str) -> str: + tenant_id = str(uuid.uuid4()) # Generate new tenant ID + + # Provision tenant on data plane + await self._provision_on_data_plane(tenant_id, email) + + # Notify control plane + await self._notify_control_plane(tenant_id, email) + + return tenant_id + + async def _provision_on_data_plane(self, tenant_id: str, email: str) -> None: + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") + + if user_owns_a_tenant(email): + raise HTTPException( + status_code=409, detail="User already belongs to an organization" + ) + + logger.info(f"Provisioning tenant: {tenant_id}") + token = None + + try: + if not ensure_schema_exists(tenant_id): + logger.info(f"Created schema for tenant {tenant_id}") + else: + logger.info(f"Schema already exists for tenant {tenant_id}") + + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + run_alembic_migrations(tenant_id) + + with get_session_with_tenant(tenant_id) as db_session: + setup_danswer(db_session, tenant_id) + configure_default_api_keys(db_session) + + add_users_to_tenant([email], tenant_id) + + except Exception as e: + logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to create tenant: {str(e)}" + ) + finally: + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + async def _notify_control_plane(self, tenant_id: str, email: str) -> None: + headers = { + "Authorization": f"Bearer {EXPECTED_API_KEY}", # Replace with your control plane API key + "Content-Type": "application/json", + } + payload = {"tenant_id": tenant_id, "email": email} + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", # Replace with your control plane URL + headers=headers, + json=payload, + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Control plane tenant creation failed: {error_text}") + raise Exception( + f"Failed to create tenant on control plane: {error_text}" + ) + + async def rollback_tenant_provisioning(self, tenant_id: str) -> None: + # Logic to rollback tenant provisioning on data plane + logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}") + try: + # Drop the tenant's schema to rollback provisioning + drop_schema(tenant_id) + # Remove tenant mapping + with Session(get_sqlalchemy_engine()) as db_session: + db_session.query(UserTenantMapping).filter( + UserTenantMapping.tenant_id == tenant_id + ).delete() + db_session.commit() + except Exception as e: + logger.error(f"Failed to rollback tenant provisioning: {e}") + + +# For now, we're implementing a primitive mapping between users and tenants. +# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). +def user_owns_a_tenant(email: str) -> bool: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + result = ( + db_session.query(UserTenantMapping) + .filter(UserTenantMapping.email == email) + .first() + ) + return result is not None + + +def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + for email in emails: + db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) + except Exception as e: + logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}") + db_session.commit() + + +def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + mappings_to_delete = ( + db_session.query(UserTenantMapping) + .filter( + UserTenantMapping.email.in_(emails), + UserTenantMapping.tenant_id == tenant_id, + ) + .all() + ) + + for mapping in mappings_to_delete: + db_session.delete(mapping) + + db_session.commit() + except Exception as e: + logger.exception( + f"Failed to remove users from tenant {tenant_id}: {str(e)}" + ) + db_session.rollback() def run_alembic_migrations(schema_name: str) -> None: @@ -98,48 +240,3 @@ def ensure_schema_exists(tenant_id: str) -> bool: db_session.execute(stmt) return True return False - - -# For now, we're implementing a primitive mapping between users and tenants. -# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). -def user_owns_a_tenant(email: str) -> bool: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - result = ( - db_session.query(UserTenantMapping) - .filter(UserTenantMapping.email == email) - .first() - ) - return result is not None - - -def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - try: - for email in emails: - db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) - except Exception as e: - logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}") - db_session.commit() - - -def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - try: - mappings_to_delete = ( - db_session.query(UserTenantMapping) - .filter( - UserTenantMapping.email.in_(emails), - UserTenantMapping.tenant_id == tenant_id, - ) - .all() - ) - - for mapping in mappings_to_delete: - db_session.delete(mapping) - - db_session.commit() - except Exception as e: - logger.exception( - f"Failed to remove users from tenant {tenant_id}: {str(e)}" - ) - db_session.rollback() diff --git a/backend/ee/danswer/server/tenants/registration.py b/backend/ee/danswer/server/tenants/registration.py new file mode 100644 index 00000000000..399a42a9925 --- /dev/null +++ b/backend/ee/danswer/server/tenants/registration.py @@ -0,0 +1,91 @@ +import logging + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Request +from fastapi import Response +from fastapi_users import exceptions +from fastapi_users.router.common import ErrorCode + +from danswer.auth.schemas import UserCreate +from danswer.auth.schemas import UserRead +from danswer.auth.users import auth_backend +from danswer.auth.users import get_jwt_strategy +from danswer.auth.users import get_tenant_id_for_email +from danswer.auth.users import get_user_manager +from danswer.auth.users import UserManager +from danswer.db.auth import SQLAlchemyUserAdminDB +from danswer.db.engine import get_async_session_with_tenant +from danswer.db.models import OAuthAccount +from danswer.db.models import User +from ee.danswer.server.tenants.provisioning import TenantProvisioningService +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR + +# Import necessary modules and functions + +logger = logging.getLogger(__name__) + +router = APIRouter() +tenant_service = TenantProvisioningService() + + +@router.post("/register", response_model=UserRead) +async def register( + user_create: UserCreate, + request: Request, + response: Response, + user_manager: UserManager = Depends(get_user_manager), + # Include any other dependencies you need +) -> UserRead: + try: + # Check if user already belongs to a tenant + tenant_id = get_tenant_id_for_email(user_create.email) + except exceptions.UserNotExists: + # User does not belong to a tenant; need to provision a new tenant + tenant_id = None + + if not tenant_id: + # Provision the tenant + try: + tenant_id = await tenant_service.provision_tenant(user_create.email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + + # Proceed with user creation + if tenant_id is None: + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + token = None + try: + # Set the tenant ID context variable + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + + async with get_async_session_with_tenant(tenant_id) as db_session: + # Set up user manager with tenant-specific user DB + tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount) + user_manager.user_db = tenant_user_db + user_manager.database = tenant_user_db + + # Create the user + user = await user_manager.create(user_create, request=request) + + # Optional: Log the user in automatically + await auth_backend.login(get_jwt_strategy(), user) + + # Convert User model to UserRead schema before returning + return UserRead.model_validate(user) + + except exceptions.UserAlreadyExists: + raise HTTPException( + status_code=400, + detail=ErrorCode.REGISTER_USER_ALREADY_EXISTS, + ) + except Exception as e: + logger.error(f"User creation failed: {e}") + # Optionally rollback tenant provisioning + await tenant_service.rollback_tenant_provisioning(tenant_id) + raise HTTPException(status_code=500, detail="Failed to create user.") + finally: + if token: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) From 91f92aa04eefd5572029fd0ad6f500c46da7431a Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 14:40:17 -0800 Subject: [PATCH 02/12] functional but scrappy --- backend/danswer/auth/users.py | 60 +++++++----- backend/danswer/main.py | 23 ++--- .../ee/danswer/server/tenants/provisioning.py | 8 +- .../ee/danswer/server/tenants/registration.py | 91 ------------------- web/src/app/layout.tsx | 1 + 5 files changed, 50 insertions(+), 133 deletions(-) delete mode 100644 backend/ee/danswer/server/tenants/registration.py diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 9cfaa2dec6a..4336d99d38f 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -56,7 +56,6 @@ from danswer.auth.invited_users import get_invited_users from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole -from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import DISABLE_VERIFICATION @@ -93,6 +92,7 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation +from ee.danswer.server.tenants.provisioning import TenantProvisioningService from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR @@ -245,14 +245,25 @@ async def create( else POSTGRES_DEFAULT_SCHEMA ) except exceptions.UserNotExists: - # We should provision a tenant - raise HTTPException(status_code=401, detail="User not found") + # Tenant does not exist; provision a new tenant + tenant_provisioning_service = TenantProvisioningService() + try: + tenant_id = await tenant_provisioning_service.provision_tenant( + user_create.email + ) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException( + status_code=500, detail="Failed to provision tenant." + ) if not tenant_id: raise HTTPException( status_code=401, detail="User does not belong to an organization" ) + # Proceed with user creation + logger.error(f"Creating user {user_create.email} in tenant {tenant_id}") async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -272,28 +283,15 @@ async def create( user_create.role = UserRole.ADMIN else: user_create.role = UserRole.BASIC - user = None + try: - user = await super().create(user_create, safe=safe, request=request) # type: ignore + user = await super().create(user_create, safe=safe, request=request) except exceptions.UserAlreadyExists: - user = await self.get_by_email(user_create.email) - # Handle case where user has used product outside of web and is now creating an account through web - if ( - not user.has_web_login - and hasattr(user_create, "has_web_login") - and user_create.has_web_login - ): - user_update = UserUpdate( - password=user_create.password, - has_web_login=True, - role=user_create.role, - is_verified=user_create.is_verified, - ) - user = await self.update(user_update, user) - else: - raise exceptions.UserAlreadyExists() + # ... existing handling of this case + raise exceptions.UserAlreadyExists() CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + return user async def oauth_callback( @@ -317,12 +315,24 @@ async def oauth_callback( else POSTGRES_DEFAULT_SCHEMA ) except exceptions.UserNotExists: - raise HTTPException(status_code=401, detail="User not found") + # Tenant does not exist; provision a new tenant + tenant_provisioning_service = TenantProvisioningService() + try: + tenant_id = await tenant_provisioning_service.provision_tenant( + account_email + ) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException( + status_code=500, detail="Failed to provision tenant." + ) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") + # Proceed with the tenant context token = None + logger.error(f"zabozeezaaGetting async session with tenant {tenant_id}") async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -372,9 +382,9 @@ async def oauth_callback( # Explicitly set the Postgres schema for this session to ensure # OAuth account creation happens in the correct tenant schema await db_session.execute(text(f'SET search_path = "{tenant_id}"')) - user = await self.user_db.add_oauth_account( - user, oauth_account_dict - ) + + # Add OAuth account + await self.user_db.add_oauth_account(user, oauth_account_dict) await self.on_after_register(user, request) else: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 434112c5d5e..06ce7bf4092 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -277,21 +277,14 @@ def get_application() -> FastAPI: prefix="/auth", tags=["auth"], ) - if not MULTI_TENANT: - include_router_with_global_prefix_prepended( - application, - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], - ) - else: - include_router_with_global_prefix_prepended( - application, - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], - ) - # + + include_router_with_global_prefix_prepended( + application, + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + ) + include_router_with_global_prefix_prepended( application, fastapi_users.get_reset_password_router(), diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index eb267e4b6bb..1c84e10f493 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -1,3 +1,4 @@ +import asyncio import logging import os import uuid @@ -27,6 +28,7 @@ from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider @@ -40,7 +42,7 @@ def drop_schema(tenant_id: str) -> None: class TenantProvisioningService: async def provision_tenant(self, email: str) -> str: - tenant_id = str(uuid.uuid4()) # Generate new tenant ID + tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID # Provision tenant on data plane await self._provision_on_data_plane(tenant_id, email) @@ -69,7 +71,9 @@ async def _provision_on_data_plane(self, tenant_id: str, email: str) -> None: logger.info(f"Schema already exists for tenant {tenant_id}") token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - run_alembic_migrations(tenant_id) + + # Await the Alembic migrations + await asyncio.to_thread(run_alembic_migrations, tenant_id) with get_session_with_tenant(tenant_id) as db_session: setup_danswer(db_session, tenant_id) diff --git a/backend/ee/danswer/server/tenants/registration.py b/backend/ee/danswer/server/tenants/registration.py deleted file mode 100644 index 399a42a9925..00000000000 --- a/backend/ee/danswer/server/tenants/registration.py +++ /dev/null @@ -1,91 +0,0 @@ -import logging - -from fastapi import APIRouter -from fastapi import Depends -from fastapi import HTTPException -from fastapi import Request -from fastapi import Response -from fastapi_users import exceptions -from fastapi_users.router.common import ErrorCode - -from danswer.auth.schemas import UserCreate -from danswer.auth.schemas import UserRead -from danswer.auth.users import auth_backend -from danswer.auth.users import get_jwt_strategy -from danswer.auth.users import get_tenant_id_for_email -from danswer.auth.users import get_user_manager -from danswer.auth.users import UserManager -from danswer.db.auth import SQLAlchemyUserAdminDB -from danswer.db.engine import get_async_session_with_tenant -from danswer.db.models import OAuthAccount -from danswer.db.models import User -from ee.danswer.server.tenants.provisioning import TenantProvisioningService -from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR - -# Import necessary modules and functions - -logger = logging.getLogger(__name__) - -router = APIRouter() -tenant_service = TenantProvisioningService() - - -@router.post("/register", response_model=UserRead) -async def register( - user_create: UserCreate, - request: Request, - response: Response, - user_manager: UserManager = Depends(get_user_manager), - # Include any other dependencies you need -) -> UserRead: - try: - # Check if user already belongs to a tenant - tenant_id = get_tenant_id_for_email(user_create.email) - except exceptions.UserNotExists: - # User does not belong to a tenant; need to provision a new tenant - tenant_id = None - - if not tenant_id: - # Provision the tenant - try: - tenant_id = await tenant_service.provision_tenant(user_create.email) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") - raise HTTPException(status_code=500, detail="Failed to provision tenant.") - - # Proceed with user creation - if tenant_id is None: - raise HTTPException(status_code=500, detail="Failed to provision tenant.") - token = None - try: - # Set the tenant ID context variable - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - - async with get_async_session_with_tenant(tenant_id) as db_session: - # Set up user manager with tenant-specific user DB - tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount) - user_manager.user_db = tenant_user_db - user_manager.database = tenant_user_db - - # Create the user - user = await user_manager.create(user_create, request=request) - - # Optional: Log the user in automatically - await auth_backend.login(get_jwt_strategy(), user) - - # Convert User model to UserRead schema before returning - return UserRead.model_validate(user) - - except exceptions.UserAlreadyExists: - raise HTTPException( - status_code=400, - detail=ErrorCode.REGISTER_USER_ALREADY_EXISTS, - ) - except Exception as e: - logger.error(f"User creation failed: {e}") - # Optionally rollback tenant provisioning - await tenant_service.rollback_tenant_provisioning(tenant_id) - raise HTTPException(status_code=500, detail="Failed to create user.") - finally: - if token: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 8918aa5a636..8a90ace9e09 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -92,6 +92,7 @@ export default async function RootLayout({ ); + console.log("combinedSettings", combinedSettings); if (!combinedSettings) { return getPageContent( From 37349c9d5c41d3f6f0469807bf348031570508b9 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 14:42:31 -0800 Subject: [PATCH 03/12] minor cleanup --- backend/danswer/auth/users.py | 22 +++++++++++++++++----- web/src/app/layout.tsx | 1 - 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 4336d99d38f..a1b3c50fd83 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -56,6 +56,7 @@ from danswer.auth.invited_users import get_invited_users from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole +from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import DISABLE_VERIFICATION @@ -262,8 +263,6 @@ async def create( status_code=401, detail="User does not belong to an organization" ) - # Proceed with user creation - logger.error(f"Creating user {user_create.email} in tenant {tenant_id}") async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -287,8 +286,22 @@ async def create( try: user = await super().create(user_create, safe=safe, request=request) except exceptions.UserAlreadyExists: - # ... existing handling of this case - raise exceptions.UserAlreadyExists() + user = await self.get_by_email(user_create.email) + # Handle case where user has used product outside of web and is now creating an account through web + if ( + not user.has_web_login + and hasattr(user_create, "has_web_login") + and user_create.has_web_login + ): + user_update = UserUpdate( + password=user_create.password, + has_web_login=True, + role=user_create.role, + is_verified=user_create.is_verified, + ) + user = await self.update(user_update, user) + else: + raise exceptions.UserAlreadyExists() CURRENT_TENANT_ID_CONTEXTVAR.reset(token) @@ -332,7 +345,6 @@ async def oauth_callback( # Proceed with the tenant context token = None - logger.error(f"zabozeezaaGetting async session with tenant {tenant_id}") async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 8a90ace9e09..8918aa5a636 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -92,7 +92,6 @@ export default async function RootLayout({ ); - console.log("combinedSettings", combinedSettings); if (!combinedSettings) { return getPageContent( From 2f34c598dcd3bd6094773b04a96d00492f89d1f4 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 14:47:45 -0800 Subject: [PATCH 04/12] minor clean up --- backend/danswer/auth/users.py | 73 +++++++++---------- .../ee/danswer/server/tenants/provisioning.py | 10 +-- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index a1b3c50fd83..afe369f7c37 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -239,29 +239,29 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: - try: - tenant_id = ( - get_tenant_id_for_email(user_create.email) - if MULTI_TENANT - else POSTGRES_DEFAULT_SCHEMA - ) - except exceptions.UserNotExists: - # Tenant does not exist; provision a new tenant - tenant_provisioning_service = TenantProvisioningService() + if MULTI_TENANT: try: - tenant_id = await tenant_provisioning_service.provision_tenant( - user_create.email - ) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") + tenant_id = get_tenant_id_for_email(user_create.email) + + except exceptions.UserNotExists: + # If tenant does not exist and in Multi tenant mode, provision a new tenant + tenant_provisioning_service = TenantProvisioningService() + try: + tenant_id = await tenant_provisioning_service.provision_tenant( + user_create.email + ) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException( + status_code=500, detail="Failed to provision tenant." + ) + + if not tenant_id: raise HTTPException( - status_code=500, detail="Failed to provision tenant." + status_code=401, detail="User does not belong to an organization" ) - - if not tenant_id: - raise HTTPException( - status_code=401, detail="User does not belong to an organization" - ) + else: + tenant_id = POSTGRES_DEFAULT_SCHEMA async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -321,24 +321,23 @@ async def oauth_callback( is_verified_by_default: bool = False, ) -> models.UOAP: # Get tenant_id from mapping table - try: - tenant_id = ( - get_tenant_id_for_email(account_email) - if MULTI_TENANT - else POSTGRES_DEFAULT_SCHEMA - ) - except exceptions.UserNotExists: - # Tenant does not exist; provision a new tenant - tenant_provisioning_service = TenantProvisioningService() + if MULTI_TENANT: try: - tenant_id = await tenant_provisioning_service.provision_tenant( - account_email - ) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") - raise HTTPException( - status_code=500, detail="Failed to provision tenant." - ) + tenant_id = get_tenant_id_for_email(account_email) + except exceptions.UserNotExists: + # Tenant does not exist; provision a new tenant + tenant_provisioning_service = TenantProvisioningService() + try: + tenant_id = await tenant_provisioning_service.provision_tenant( + account_email + ) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException( + status_code=500, detail="Failed to provision tenant." + ) + else: + tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 1c84e10f493..7aa234a2d02 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -35,11 +35,6 @@ logger = logging.getLogger(__name__) -def drop_schema(tenant_id: str) -> None: - with get_sqlalchemy_engine().connect() as connection: - connection.execute(text(f"DROP SCHEMA IF EXISTS {tenant_id} CASCADE")) - - class TenantProvisioningService: async def provision_tenant(self, email: str) -> str: tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID @@ -244,3 +239,8 @@ def ensure_schema_exists(tenant_id: str) -> bool: db_session.execute(stmt) return True return False + + +def drop_schema(tenant_id: str) -> None: + with get_sqlalchemy_engine().connect() as connection: + connection.execute(text(f"DROP SCHEMA IF EXISTS {tenant_id} CASCADE")) From 1dc80d6b0262c96b01653b9ccea6d431ab9b9bed Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 14:54:55 -0800 Subject: [PATCH 05/12] k --- backend/danswer/background/celery/apps/beat.py | 8 ++++---- backend/danswer/server/query_and_chat/chat_backend.py | 2 +- backend/danswer/server/query_and_chat/query_backend.py | 2 +- backend/ee/danswer/server/tenants/provisioning.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/backend/danswer/background/celery/apps/beat.py b/backend/danswer/background/celery/apps/beat.py index 5ef887121dc..979cf07cbb1 100644 --- a/backend/danswer/background/celery/apps/beat.py +++ b/backend/danswer/background/celery/apps/beat.py @@ -119,10 +119,10 @@ def _update_tenant_tasks(self) -> None: else: logger.info("Schedule is up to date, no changes needed") - except (AttributeError, KeyError) as e: - logger.exception(f"Failed to process task configuration: {str(e)}") - except Exception as e: - logger.exception(f"Unexpected error updating tenant tasks: {str(e)}") + except (AttributeError, KeyError): + logger.exception("Failed to process task configuration") + except Exception: + logger.exception("Unexpected error updating tenant tasks") def _should_update_schedule( self, current_schedule: dict, new_schedule: dict diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index c1f4a7b3970..41176a0453f 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -359,7 +359,7 @@ def stream_generator() -> Generator[str, None, None]: yield json.dumps(packet) if isinstance(packet, dict) else packet except Exception as e: - logger.exception(f"Error in chat message streaming: {e}") + logger.exception("Error in chat message streaming") yield json.dumps({"error": str(e)}) finally: diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 1b8d5dc4b5e..4d6767ac2e2 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -279,7 +279,7 @@ def stream_generator() -> Generator[str, None, None]: ): yield json.dumps(packet) if isinstance(packet, dict) else packet except Exception as e: - logger.exception(f"Error in search answer streaming: {e}") + logger.exception("Error in search answer streaming") yield json.dumps({"error": str(e)}) return StreamingResponse(stream_generator(), media_type="application/json") diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 7aa234a2d02..4f9189f476c 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -77,7 +77,7 @@ async def _provision_on_data_plane(self, tenant_id: str, email: str) -> None: add_users_to_tenant([email], tenant_id) except Exception as e: - logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}") + logger.exception(f"Failed to create tenant {tenant_id}") raise HTTPException( status_code=500, detail=f"Failed to create tenant: {str(e)}" ) @@ -138,8 +138,8 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: try: for email in emails: db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) - except Exception as e: - logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}") + except Exception: + logger.exception(f"Failed to add users to tenant {tenant_id}") db_session.commit() From 3f2951c029939338deecb066c0f1a8482755f6fe Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 15:27:13 -0800 Subject: [PATCH 06/12] simplify --- backend/danswer/auth/users.py | 48 ++----------------- .../ee/danswer/server/tenants/provisioning.py | 27 +++++++++++ 2 files changed, 32 insertions(+), 43 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index afe369f7c37..11387298a3b 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -93,7 +93,7 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation -from ee.danswer.server.tenants.provisioning import TenantProvisioningService +from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR @@ -239,29 +239,7 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: - if MULTI_TENANT: - try: - tenant_id = get_tenant_id_for_email(user_create.email) - - except exceptions.UserNotExists: - # If tenant does not exist and in Multi tenant mode, provision a new tenant - tenant_provisioning_service = TenantProvisioningService() - try: - tenant_id = await tenant_provisioning_service.provision_tenant( - user_create.email - ) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") - raise HTTPException( - status_code=500, detail="Failed to provision tenant." - ) - - if not tenant_id: - raise HTTPException( - status_code=401, detail="User does not belong to an organization" - ) - else: - tenant_id = POSTGRES_DEFAULT_SCHEMA + tenant_id = get_or_create_tenant_id(user_create.email) async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -303,7 +281,8 @@ async def create( else: raise exceptions.UserAlreadyExists() - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + finally: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) return user @@ -320,24 +299,7 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - # Get tenant_id from mapping table - if MULTI_TENANT: - try: - tenant_id = get_tenant_id_for_email(account_email) - except exceptions.UserNotExists: - # Tenant does not exist; provision a new tenant - tenant_provisioning_service = TenantProvisioningService() - try: - tenant_id = await tenant_provisioning_service.provision_tenant( - account_email - ) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") - raise HTTPException( - status_code=500, detail="Failed to provision tenant." - ) - else: - tenant_id = POSTGRES_DEFAULT_SCHEMA + tenant_id = get_or_create_tenant_id(account_email) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 4f9189f476c..7b6790cc2f1 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -12,6 +12,8 @@ from alembic import command from alembic.config import Config +from danswer.auth.users import exceptions +from danswer.auth.users import get_tenant_id_for_email from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.configs.app_configs import EXPECTED_API_KEY from danswer.db.engine import build_connection_string @@ -32,9 +34,34 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider + logger = logging.getLogger(__name__) +async def get_or_create_tenant_id(email: str) -> str: + """Get existing tenant ID for an email or create a new tenant if none exists.""" + if not MULTI_TENANT: + return POSTGRES_DEFAULT_SCHEMA + + try: + tenant_id = get_tenant_id_for_email(email) + except exceptions.UserNotExists: + # If tenant does not exist and in Multi tenant mode, provision a new tenant + tenant_provisioning_service = TenantProvisioningService() + try: + tenant_id = await tenant_provisioning_service.provision_tenant(email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + + if not tenant_id: + raise HTTPException( + status_code=401, detail="User does not belong to an organization" + ) + + return tenant_id + + class TenantProvisioningService: async def provision_tenant(self, email: str) -> str: tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID From d52e353c7fb8cadbcbf976689b8265147b2468b6 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 16:16:30 -0800 Subject: [PATCH 07/12] update provisioning --- backend/danswer/auth/users.py | 7 +- backend/danswer/server/manage/users.py | 2 +- backend/ee/danswer/server/tenants/api.py | 54 ---- .../ee/danswer/server/tenants/provisioning.py | 254 ++++++------------ .../server/tenants/schema_management.py | 76 ++++++ .../ee/danswer/server/tenants/user_mapping.py | 50 ++++ 6 files changed, 206 insertions(+), 237 deletions(-) create mode 100644 backend/ee/danswer/server/tenants/schema_management.py create mode 100644 backend/ee/danswer/server/tenants/user_mapping.py diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 11387298a3b..6cd15362239 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -98,7 +98,6 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR - logger = setup_logger() @@ -239,7 +238,7 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: - tenant_id = get_or_create_tenant_id(user_create.email) + tenant_id = await get_or_create_tenant_id(user_create.email) async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -262,7 +261,7 @@ async def create( user_create.role = UserRole.BASIC try: - user = await super().create(user_create, safe=safe, request=request) + user = await super().create(user_create, safe=safe, request=request) # type: ignore except exceptions.UserAlreadyExists: user = await self.get_by_email(user_create.email) # Handle case where user has used product outside of web and is now creating an account through web @@ -299,7 +298,7 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - tenant_id = get_or_create_tenant_id(account_email) + tenant_id = await get_or_create_tenant_id(account_email) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 7802067b0ca..819781fc740 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -66,7 +66,7 @@ from ee.danswer.db.user_group import remove_curator_status__no_commit from ee.danswer.server.tenants.billing import register_tenant_users from ee.danswer.server.tenants.provisioning import add_users_to_tenant -from ee.danswer.server.tenants.provisioning import remove_users_from_tenant +from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant from shared_configs.configs import MULTI_TENANT logger = setup_logger() diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py index 8e79c0b37b1..e6d0e048c83 100644 --- a/backend/ee/danswer/server/tenants/api.py +++ b/backend/ee/danswer/server/tenants/api.py @@ -15,7 +15,6 @@ from danswer.db.users import get_user_by_email from danswer.server.settings.store import load_settings from danswer.server.settings.store import store_settings -from danswer.setup import setup_danswer from danswer.utils.logger import setup_logger from ee.danswer.auth.users import current_cloud_superuser from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY @@ -23,15 +22,8 @@ from ee.danswer.server.tenants.billing import fetch_billing_information from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information from ee.danswer.server.tenants.models import BillingInformation -from ee.danswer.server.tenants.models import CreateTenantRequest from ee.danswer.server.tenants.models import ImpersonateRequest from ee.danswer.server.tenants.models import ProductGatingRequest -from ee.danswer.server.tenants.provisioning import add_users_to_tenant -from ee.danswer.server.tenants.provisioning import configure_default_api_keys -from ee.danswer.server.tenants.provisioning import ensure_schema_exists -from ee.danswer.server.tenants.provisioning import run_alembic_migrations -from ee.danswer.server.tenants.provisioning import user_owns_a_tenant -from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR stripe.api_key = STRIPE_SECRET_KEY @@ -40,52 +32,6 @@ router = APIRouter(prefix="/tenants") -@router.post("/create") -def create_tenant( - create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep) -) -> dict[str, str]: - if not MULTI_TENANT: - raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - - tenant_id = create_tenant_request.tenant_id - email = create_tenant_request.initial_admin_email - token = None - - if user_owns_a_tenant(email): - raise HTTPException( - status_code=409, detail="User already belongs to an organization" - ) - - try: - if not ensure_schema_exists(tenant_id): - logger.info(f"Created schema for tenant {tenant_id}") - else: - logger.info(f"Schema already exists for tenant {tenant_id}") - - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - run_alembic_migrations(tenant_id) - - with get_session_with_tenant(tenant_id) as db_session: - setup_danswer(db_session, tenant_id) - - configure_default_api_keys(db_session) - - add_users_to_tenant([email], tenant_id) - - return { - "status": "success", - "message": f"Tenant {tenant_id} created successfully", - } - except Exception as e: - logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Failed to create tenant: {str(e)}" - ) - finally: - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) - - @router.post("/product-gating") def gate_product( product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep) diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 7b6790cc2f1..0cb3821bc24 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -1,22 +1,15 @@ import asyncio import logging -import os import uuid -from types import SimpleNamespace import aiohttp # Async HTTP client from fastapi import HTTPException -from sqlalchemy import text from sqlalchemy.orm import Session -from sqlalchemy.schema import CreateSchema -from alembic import command -from alembic.config import Config from danswer.auth.users import exceptions from danswer.auth.users import get_tenant_id_for_email from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.configs.app_configs import EXPECTED_API_KEY -from danswer.db.engine import build_connection_string from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.llm import upsert_cloud_embedding_provider @@ -28,13 +21,17 @@ from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY +from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists +from ee.danswer.server.tenants.schema_management import drop_schema +from ee.danswer.server.tenants.schema_management import run_alembic_migrations +from ee.danswer.server.tenants.user_mapping import add_users_to_tenant +from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider - logger = logging.getLogger(__name__) @@ -47,9 +44,8 @@ async def get_or_create_tenant_id(email: str) -> str: tenant_id = get_tenant_id_for_email(email) except exceptions.UserNotExists: # If tenant does not exist and in Multi tenant mode, provision a new tenant - tenant_provisioning_service = TenantProvisioningService() try: - tenant_id = await tenant_provisioning_service.provision_tenant(email) + tenant_id = await create_tenant(email) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") raise HTTPException(status_code=500, detail="Failed to provision tenant.") @@ -62,170 +58,94 @@ async def get_or_create_tenant_id(email: str) -> str: return tenant_id -class TenantProvisioningService: - async def provision_tenant(self, email: str) -> str: - tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID - +async def create_tenant(email: str) -> str: + tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) + try: # Provision tenant on data plane - await self._provision_on_data_plane(tenant_id, email) - + await provision_tenant(tenant_id, email) # Notify control plane - await self._notify_control_plane(tenant_id, email) - - return tenant_id - - async def _provision_on_data_plane(self, tenant_id: str, email: str) -> None: - if not MULTI_TENANT: - raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - - if user_owns_a_tenant(email): - raise HTTPException( - status_code=409, detail="User already belongs to an organization" - ) - - logger.info(f"Provisioning tenant: {tenant_id}") - token = None - - try: - if not ensure_schema_exists(tenant_id): - logger.info(f"Created schema for tenant {tenant_id}") - else: - logger.info(f"Schema already exists for tenant {tenant_id}") + await notify_control_plane(tenant_id, email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + await rollback_tenant_provisioning(tenant_id) + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + return tenant_id - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - # Await the Alembic migrations - await asyncio.to_thread(run_alembic_migrations, tenant_id) +async def provision_tenant(tenant_id: str, email: str) -> None: + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - with get_session_with_tenant(tenant_id) as db_session: - setup_danswer(db_session, tenant_id) - configure_default_api_keys(db_session) + if user_owns_a_tenant(email): + raise HTTPException( + status_code=409, detail="User already belongs to an organization" + ) - add_users_to_tenant([email], tenant_id) + logger.info(f"Provisioning tenant: {tenant_id}") + token = None - except Exception as e: - logger.exception(f"Failed to create tenant {tenant_id}") - raise HTTPException( - status_code=500, detail=f"Failed to create tenant: {str(e)}" - ) - finally: - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + try: + if not create_schema_if_not_exists(tenant_id): + logger.info(f"Created schema for tenant {tenant_id}") + else: + logger.info(f"Schema already exists for tenant {tenant_id}") - async def _notify_control_plane(self, tenant_id: str, email: str) -> None: - headers = { - "Authorization": f"Bearer {EXPECTED_API_KEY}", # Replace with your control plane API key - "Content-Type": "application/json", - } - payload = {"tenant_id": tenant_id, "email": email} + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - async with aiohttp.ClientSession() as session: - async with session.post( - f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", # Replace with your control plane URL - headers=headers, - json=payload, - ) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Control plane tenant creation failed: {error_text}") - raise Exception( - f"Failed to create tenant on control plane: {error_text}" - ) + # Await the Alembic migrations + await asyncio.to_thread(run_alembic_migrations, tenant_id) - async def rollback_tenant_provisioning(self, tenant_id: str) -> None: - # Logic to rollback tenant provisioning on data plane - logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}") - try: - # Drop the tenant's schema to rollback provisioning - drop_schema(tenant_id) - # Remove tenant mapping - with Session(get_sqlalchemy_engine()) as db_session: - db_session.query(UserTenantMapping).filter( - UserTenantMapping.tenant_id == tenant_id - ).delete() - db_session.commit() - except Exception as e: - logger.error(f"Failed to rollback tenant provisioning: {e}") + with get_session_with_tenant(tenant_id) as db_session: + setup_danswer(db_session, tenant_id) + configure_default_api_keys(db_session) + add_users_to_tenant([email], tenant_id) -# For now, we're implementing a primitive mapping between users and tenants. -# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). -def user_owns_a_tenant(email: str) -> bool: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - result = ( - db_session.query(UserTenantMapping) - .filter(UserTenantMapping.email == email) - .first() + except Exception as e: + logger.exception(f"Failed to create tenant {tenant_id}") + raise HTTPException( + status_code=500, detail=f"Failed to create tenant: {str(e)}" ) - return result is not None - - -def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - try: - for email in emails: - db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) - except Exception: - logger.exception(f"Failed to add users to tenant {tenant_id}") - db_session.commit() - - -def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - try: - mappings_to_delete = ( - db_session.query(UserTenantMapping) - .filter( - UserTenantMapping.email.in_(emails), - UserTenantMapping.tenant_id == tenant_id, + finally: + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + +async def notify_control_plane(tenant_id: str, email: str) -> None: + headers = { + "Authorization": f"Bearer {EXPECTED_API_KEY}", + "Content-Type": "application/json", + } + payload = {"tenant_id": tenant_id, "email": email} + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", + headers=headers, + json=payload, + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Control plane tenant creation failed: {error_text}") + raise Exception( + f"Failed to create tenant on control plane: {error_text}" ) - .all() - ) - for mapping in mappings_to_delete: - db_session.delete(mapping) - - db_session.commit() - except Exception as e: - logger.exception( - f"Failed to remove users from tenant {tenant_id}: {str(e)}" - ) - db_session.rollback() - - -def run_alembic_migrations(schema_name: str) -> None: - logger.info(f"Starting Alembic migrations for schema: {schema_name}") +async def rollback_tenant_provisioning(tenant_id: str) -> None: + # Logic to rollback tenant provisioning on data plane + logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}") try: - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) - alembic_ini_path = os.path.join(root_dir, "alembic.ini") - - # Configure Alembic - alembic_cfg = Config(alembic_ini_path) - alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) - alembic_cfg.set_main_option( - "script_location", os.path.join(root_dir, "alembic") - ) - - # Ensure that logging isn't broken - alembic_cfg.attributes["configure_logger"] = False - - # Mimic command-line options by adding 'cmd_opts' to the config - alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore - alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore - - # Run migrations programmatically - command.upgrade(alembic_cfg, "head") - - # Run migrations programmatically - logger.info( - f"Alembic migrations completed successfully for schema: {schema_name}" - ) - + # Drop the tenant's schema to rollback provisioning + drop_schema(tenant_id) + # Remove tenant mapping + with Session(get_sqlalchemy_engine()) as db_session: + db_session.query(UserTenantMapping).filter( + UserTenantMapping.tenant_id == tenant_id + ).delete() + db_session.commit() except Exception as e: - logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") - raise + logger.error(f"Failed to rollback tenant provisioning: {e}") def configure_default_api_keys(db_session: Session) -> None: @@ -249,25 +169,3 @@ def configure_default_api_keys(db_session: Session) -> None: api_key=COHERE_DEFAULT_API_KEY, ) upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) - - -def ensure_schema_exists(tenant_id: str) -> bool: - with Session(get_sqlalchemy_engine()) as db_session: - with db_session.begin(): - result = db_session.execute( - text( - "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" - ), - {"schema_name": tenant_id}, - ) - schema_exists = result.scalar() is not None - if not schema_exists: - stmt = CreateSchema(tenant_id) - db_session.execute(stmt) - return True - return False - - -def drop_schema(tenant_id: str) -> None: - with get_sqlalchemy_engine().connect() as connection: - connection.execute(text(f"DROP SCHEMA IF EXISTS {tenant_id} CASCADE")) diff --git a/backend/ee/danswer/server/tenants/schema_management.py b/backend/ee/danswer/server/tenants/schema_management.py new file mode 100644 index 00000000000..9be4e79f984 --- /dev/null +++ b/backend/ee/danswer/server/tenants/schema_management.py @@ -0,0 +1,76 @@ +import logging +import os +from types import SimpleNamespace + +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.schema import CreateSchema + +from alembic import command +from alembic.config import Config +from danswer.db.engine import build_connection_string +from danswer.db.engine import get_sqlalchemy_engine + +logger = logging.getLogger(__name__) + + +def run_alembic_migrations(schema_name: str) -> None: + logger.info(f"Starting Alembic migrations for schema: {schema_name}") + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) + alembic_ini_path = os.path.join(root_dir, "alembic.ini") + + # Configure Alembic + alembic_cfg = Config(alembic_ini_path) + alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) + alembic_cfg.set_main_option( + "script_location", os.path.join(root_dir, "alembic") + ) + + # Ensure that logging isn't broken + alembic_cfg.attributes["configure_logger"] = False + + # Mimic command-line options by adding 'cmd_opts' to the config + alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore + alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore + + # Run migrations programmatically + command.upgrade(alembic_cfg, "head") + + # Run migrations programmatically + logger.info( + f"Alembic migrations completed successfully for schema: {schema_name}" + ) + + except Exception as e: + logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") + raise + + +def create_schema_if_not_exists(tenant_id: str) -> bool: + with Session(get_sqlalchemy_engine()) as db_session: + with db_session.begin(): + result = db_session.execute( + text( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" + ), + {"schema_name": tenant_id}, + ) + schema_exists = result.scalar() is not None + if not schema_exists: + stmt = CreateSchema(tenant_id) + db_session.execute(stmt) + return True + return False + + +def drop_schema(tenant_id: str) -> None: + if not tenant_id.isidentifier(): + raise ValueError("Invalid tenant_id.") + with get_sqlalchemy_engine().connect() as connection: + connection.execute( + text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"), + {"schema_name": tenant_id}, + ) diff --git a/backend/ee/danswer/server/tenants/user_mapping.py b/backend/ee/danswer/server/tenants/user_mapping.py new file mode 100644 index 00000000000..3a3e9befc59 --- /dev/null +++ b/backend/ee/danswer/server/tenants/user_mapping.py @@ -0,0 +1,50 @@ +import logging + +from danswer.db.engine import get_session_with_tenant +from danswer.db.models import UserTenantMapping +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA + +logger = logging.getLogger(__name__) + + +def user_owns_a_tenant(email: str) -> bool: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + result = ( + db_session.query(UserTenantMapping) + .filter(UserTenantMapping.email == email) + .first() + ) + return result is not None + + +def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + for email in emails: + db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) + except Exception: + logger.exception(f"Failed to add users to tenant {tenant_id}") + db_session.commit() + + +def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + mappings_to_delete = ( + db_session.query(UserTenantMapping) + .filter( + UserTenantMapping.email.in_(emails), + UserTenantMapping.tenant_id == tenant_id, + ) + .all() + ) + + for mapping in mappings_to_delete: + db_session.delete(mapping) + + db_session.commit() + except Exception as e: + logger.exception( + f"Failed to remove users from tenant {tenant_id}: {str(e)}" + ) + db_session.rollback() From 849f4c6e5a5d32ada92d16f0b6f7ab7735b32851 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 16:31:31 -0800 Subject: [PATCH 08/12] improve import logic --- backend/danswer/auth/users.py | 19 +------------------ backend/danswer/server/manage/users.py | 2 +- backend/ee/danswer/server/tenants/api.py | 2 +- .../ee/danswer/server/tenants/provisioning.py | 2 +- .../ee/danswer/server/tenants/user_mapping.py | 18 ++++++++++++++++++ 5 files changed, 22 insertions(+), 21 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 6cd15362239..6c162dfb88c 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -48,7 +48,6 @@ from httpx_oauth.oauth2 import BaseOAuth2 from httpx_oauth.oauth2 import OAuth2Token from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy import text from sqlalchemy.orm import attributes from sqlalchemy.orm import Session @@ -83,19 +82,17 @@ from danswer.db.engine import get_async_session_with_tenant from danswer.db.engine import get_session from danswer.db.engine import get_session_with_tenant -from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount from danswer.db.models import User -from danswer.db.models import UserTenantMapping from danswer.db.users import get_user_by_email from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email from shared_configs.configs import MULTI_TENANT -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() @@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None: ) -def get_tenant_id_for_email(email: str) -> str: - if not MULTI_TENANT: - return POSTGRES_DEFAULT_SCHEMA - # Implement logic to get tenant_id from the mapping table - with Session(get_sqlalchemy_engine()) as db_session: - result = db_session.execute( - select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email) - ) - tenant_id = result.scalar_one_or_none() - if tenant_id is None: - raise exceptions.UserNotExists() - return tenant_id - - def send_user_verification_email( user_email: str, token: str, diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 819781fc740..5bfbce20902 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -30,7 +30,6 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import get_tenant_id_for_email from danswer.auth.users import optional_user from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import ENABLE_EMAIL_INVITES @@ -66,6 +65,7 @@ from ee.danswer.db.user_group import remove_curator_status__no_commit from ee.danswer.server.tenants.billing import register_tenant_users from ee.danswer.server.tenants.provisioning import add_users_to_tenant +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant from shared_configs.configs import MULTI_TENANT diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py index e6d0e048c83..8c1331c15a6 100644 --- a/backend/ee/danswer/server/tenants/api.py +++ b/backend/ee/danswer/server/tenants/api.py @@ -7,7 +7,6 @@ from danswer.auth.users import auth_backend from danswer.auth.users import current_admin_user from danswer.auth.users import get_jwt_strategy -from danswer.auth.users import get_tenant_id_for_email from danswer.auth.users import User from danswer.configs.app_configs import WEB_DOMAIN from danswer.db.engine import get_session_with_tenant @@ -24,6 +23,7 @@ from ee.danswer.server.tenants.models import BillingInformation from ee.danswer.server.tenants.models import ImpersonateRequest from ee.danswer.server.tenants.models import ProductGatingRequest +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR stripe.api_key = STRIPE_SECRET_KEY diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 0cb3821bc24..17e0d11cb6e 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -7,7 +7,6 @@ from sqlalchemy.orm import Session from danswer.auth.users import exceptions -from danswer.auth.users import get_tenant_id_for_email from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.configs.app_configs import EXPECTED_API_KEY from danswer.db.engine import get_session_with_tenant @@ -25,6 +24,7 @@ from ee.danswer.server.tenants.schema_management import drop_schema from ee.danswer.server.tenants.schema_management import run_alembic_migrations from ee.danswer.server.tenants.user_mapping import add_users_to_tenant +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA diff --git a/backend/ee/danswer/server/tenants/user_mapping.py b/backend/ee/danswer/server/tenants/user_mapping.py index 3a3e9befc59..6d25eb0d9b0 100644 --- a/backend/ee/danswer/server/tenants/user_mapping.py +++ b/backend/ee/danswer/server/tenants/user_mapping.py @@ -1,12 +1,30 @@ import logging +from fastapi_users import exceptions +from sqlalchemy import select +from sqlalchemy.orm import Session + from danswer.db.engine import get_session_with_tenant +from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import UserTenantMapping from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA + logger = logging.getLogger(__name__) +def get_tenant_id_for_email(email: str) -> str: + # Implement logic to get tenant_id from the mapping table + with Session(get_sqlalchemy_engine()) as db_session: + result = db_session.execute( + select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email) + ) + tenant_id = result.scalar_one_or_none() + if tenant_id is None: + raise exceptions.UserNotExists() + return tenant_id + + def user_owns_a_tenant(email: str) -> bool: with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: result = ( From 0f08e3648463534ff213f48909329f199efec953 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 16:55:08 -0800 Subject: [PATCH 09/12] ensure proper conditional --- backend/ee/danswer/server/tenants/user_mapping.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/ee/danswer/server/tenants/user_mapping.py b/backend/ee/danswer/server/tenants/user_mapping.py index 6d25eb0d9b0..cf0e5ec5f21 100644 --- a/backend/ee/danswer/server/tenants/user_mapping.py +++ b/backend/ee/danswer/server/tenants/user_mapping.py @@ -7,13 +7,15 @@ from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import UserTenantMapping +from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA - logger = logging.getLogger(__name__) def get_tenant_id_for_email(email: str) -> str: + if not MULTI_TENANT: + return POSTGRES_DEFAULT_SCHEMA # Implement logic to get tenant_id from the mapping table with Session(get_sqlalchemy_engine()) as db_session: result = db_session.execute( From af88f7e8d07cce5cb980e5fab4e7aa7f17e98e74 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 17:25:49 -0800 Subject: [PATCH 10/12] minor pydantic update --- backend/ee/danswer/server/tenants/models.py | 5 +++++ backend/ee/danswer/server/tenants/provisioning.py | 9 ++++++--- deployment/cloud_kubernetes/workers/beat.yaml | 4 ++-- deployment/cloud_kubernetes/workers/heavy_worker.yaml | 4 ++-- deployment/cloud_kubernetes/workers/indexing_worker.yaml | 4 ++-- deployment/cloud_kubernetes/workers/light_worker.yaml | 4 ++-- deployment/cloud_kubernetes/workers/primary.yaml | 4 ++-- 7 files changed, 21 insertions(+), 13 deletions(-) diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py index 2c1fdbecdb3..df24ff6c32d 100644 --- a/backend/ee/danswer/server/tenants/models.py +++ b/backend/ee/danswer/server/tenants/models.py @@ -33,3 +33,8 @@ class CheckoutSessionCreationResponse(BaseModel): class ImpersonateRequest(BaseModel): email: str + + +class TenantCreationPayload(BaseModel): + tenant_id: str + email: str diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 17e0d11cb6e..977d94cce69 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -8,7 +8,6 @@ from danswer.auth.users import exceptions from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL -from danswer.configs.app_configs import EXPECTED_API_KEY from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.llm import upsert_cloud_embedding_provider @@ -20,6 +19,8 @@ from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY +from ee.danswer.server.tenants.access import generate_data_plane_token +from ee.danswer.server.tenants.models import TenantCreationPayload from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists from ee.danswer.server.tenants.schema_management import drop_schema from ee.danswer.server.tenants.schema_management import run_alembic_migrations @@ -112,11 +113,13 @@ async def provision_tenant(tenant_id: str, email: str) -> None: async def notify_control_plane(tenant_id: str, email: str) -> None: + logger.info("Fetching billing information") + token = generate_data_plane_token() headers = { - "Authorization": f"Bearer {EXPECTED_API_KEY}", + "Authorization": f"Bearer {token}", "Content-Type": "application/json", } - payload = {"tenant_id": tenant_id, "email": email} + payload = TenantCreationPayload(tenant_id=tenant_id, email=email) async with aiohttp.ClientSession() as session: async with session.post( diff --git a/deployment/cloud_kubernetes/workers/beat.yaml b/deployment/cloud_kubernetes/workers/beat.yaml index a9d053f7295..563dbf10435 100644 --- a/deployment/cloud_kubernetes/workers/beat.yaml +++ b/deployment/cloud_kubernetes/workers/beat.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-beat - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -31,7 +31,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/heavy_worker.yaml b/deployment/cloud_kubernetes/workers/heavy_worker.yaml index 682cadee647..d8da6a3d3ae 100644 --- a/deployment/cloud_kubernetes/workers/heavy_worker.yaml +++ b/deployment/cloud_kubernetes/workers/heavy_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-heavy - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/indexing_worker.yaml b/deployment/cloud_kubernetes/workers/indexing_worker.yaml index 286cd3036cd..98158f62ef8 100644 --- a/deployment/cloud_kubernetes/workers/indexing_worker.yaml +++ b/deployment/cloud_kubernetes/workers/indexing_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-indexing - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/light_worker.yaml b/deployment/cloud_kubernetes/workers/light_worker.yaml index 055fac836e1..2df3b50ea53 100644 --- a/deployment/cloud_kubernetes/workers/light_worker.yaml +++ b/deployment/cloud_kubernetes/workers/light_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-light - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/primary.yaml b/deployment/cloud_kubernetes/workers/primary.yaml index 7408e3bfb42..32e34b5cdfc 100644 --- a/deployment/cloud_kubernetes/workers/primary.yaml +++ b/deployment/cloud_kubernetes/workers/primary.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-primary - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap From 2ec3b44ed2d4ce599d758cc1677f2f8e7dcd640d Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 5 Nov 2024 11:14:09 -0800 Subject: [PATCH 11/12] minor config update --- .../ee/danswer/server/tenants/provisioning.py | 78 ++++++++++++++----- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 977d94cce69..e956cf4359c 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -10,9 +10,14 @@ from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.llm import update_default_provider from danswer.db.llm import upsert_cloud_embedding_provider from danswer.db.llm import upsert_llm_provider from danswer.db.models import UserTenantMapping +from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES +from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME +from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES +from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.setup import setup_danswer @@ -125,7 +130,7 @@ async def notify_control_plane(tenant_id: str, email: str) -> None: async with session.post( f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", headers=headers, - json=payload, + json=payload.model_dump(), ) as response: if response.status != 200: error_text = await response.text() @@ -152,23 +157,54 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None: def configure_default_api_keys(db_session: Session) -> None: - open_provider = LLMProviderUpsertRequest( - name="OpenAI", - provider="OpenAI", - api_key=OPENAI_DEFAULT_API_KEY, - default_model_name="gpt-4o", - ) - anthropic_provider = LLMProviderUpsertRequest( - name="Anthropic", - provider="Anthropic", - api_key=ANTHROPIC_DEFAULT_API_KEY, - default_model_name="claude-3-5-sonnet-20240620", - ) - upsert_llm_provider(open_provider, db_session) - upsert_llm_provider(anthropic_provider, db_session) - - cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( - provider_type=EmbeddingProvider.COHERE, - api_key=COHERE_DEFAULT_API_KEY, - ) - upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) + if OPENAI_DEFAULT_API_KEY: + open_provider = LLMProviderUpsertRequest( + name="OpenAI", + provider=OPENAI_PROVIDER_NAME, + api_key=OPENAI_DEFAULT_API_KEY, + default_model_name="gpt-4", + fast_default_model_name="gpt-4o-mini", + model_names=OPEN_AI_MODEL_NAMES, + ) + try: + full_provider = upsert_llm_provider(open_provider, db_session) + update_default_provider(full_provider.id, db_session) + except Exception as e: + logger.error(f"Failed to configure OpenAI provider: {e}") + else: + logger.error( + "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration" + ) + + if ANTHROPIC_DEFAULT_API_KEY: + anthropic_provider = LLMProviderUpsertRequest( + name="Anthropic", + provider=ANTHROPIC_PROVIDER_NAME, + api_key=ANTHROPIC_DEFAULT_API_KEY, + default_model_name="claude-3-5-sonnet-20241022", + fast_default_model_name="claude-3-5-sonnet-20241022", + model_names=ANTHROPIC_MODEL_NAMES, + ) + try: + full_provider = upsert_llm_provider(anthropic_provider, db_session) + update_default_provider(full_provider.id, db_session) + except Exception as e: + logger.error(f"Failed to configure Anthropic provider: {e}") + else: + logger.error( + "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration" + ) + + if COHERE_DEFAULT_API_KEY: + cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( + provider_type=EmbeddingProvider.COHERE, + api_key=COHERE_DEFAULT_API_KEY, + ) + try: + upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) + except Exception as e: + logger.error(f"Failed to configure Cohere embedding provider: {e}") + else: + logger.error( + "COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration" + ) From 2ef37934be6b746d1e96445563eabfadef89f19e Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 5 Nov 2024 11:15:25 -0800 Subject: [PATCH 12/12] nit --- backend/ee/danswer/auth/tenant.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 backend/ee/danswer/auth/tenant.py diff --git a/backend/ee/danswer/auth/tenant.py b/backend/ee/danswer/auth/tenant.py deleted file mode 100644 index e69de29bb2d..00000000000