Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate tenant upgrades to data plane #3051

Merged
merged 12 commits into from
Nov 8, 2024
Merged
58 changes: 12 additions & 46 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -83,21 +82,19 @@
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No ee imports outside of ee folder

from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()


Expand Down Expand Up @@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None:
)


def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id


def send_user_verification_email(
user_email: str,
token: str,
Expand Down Expand Up @@ -238,19 +221,7 @@ async def create(
safe: bool = False,
request: Optional[Request] = None,
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")

if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
tenant_id = await get_or_create_tenant_id(user_create.email)

async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
Expand All @@ -271,7 +242,7 @@ async def create(
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None

try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
Expand All @@ -292,7 +263,9 @@ async def create(
else:
raise exceptions.UserAlreadyExists()

CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

return user

async def oauth_callback(
Expand All @@ -308,19 +281,12 @@ async def oauth_callback(
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await get_or_create_tenant_id(account_email)

if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")

# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
Expand Down Expand Up @@ -371,9 +337,9 @@ async def oauth_callback(
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)

# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)

else:
Expand Down
8 changes: 4 additions & 4 deletions backend/danswer/background/celery/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def _update_tenant_tasks(self) -> None:
else:
logger.info("Schedule is up to date, no changes needed")

except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we not want these?

except Exception:
logger.exception("Unexpected error updating tenant tasks")

def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,14 @@ def get_application() -> FastAPI:
prefix="/auth",
tags=["auth"],
)

include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)

include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),
Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
Expand Down Expand Up @@ -66,7 +65,8 @@
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def stream_generator() -> Generator[str, None, None]:
yield json.dumps(packet) if isinstance(packet, dict) else packet

except Exception as e:
logger.exception(f"Error in chat message streaming: {e}")
logger.exception("Error in chat message streaming")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we changing this?

yield json.dumps({"error": str(e)})

finally:
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/server/query_and_chat/query_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def stream_generator() -> Generator[str, None, None]:
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in search answer streaming: {e}")
logger.exception("Error in search answer streaming")
yield json.dumps({"error": str(e)})

return StreamingResponse(stream_generator(), media_type="application/json")
56 changes: 1 addition & 55 deletions backend/ee/danswer/server/tenants/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,23 @@
from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

stripe.api_key = STRIPE_SECRET_KEY
Expand All @@ -40,52 +32,6 @@
router = APIRouter(prefix="/tenants")


@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")

tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None

if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)

try:
if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")

token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
run_alembic_migrations(tenant_id)

with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)

configure_default_api_keys(db_session)

add_users_to_tenant([email], tenant_id)

return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
}
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
Expand Down
5 changes: 5 additions & 0 deletions backend/ee/danswer/server/tenants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,8 @@ class CheckoutSessionCreationResponse(BaseModel):

class ImpersonateRequest(BaseModel):
email: str


class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
Loading
Loading