Commit

Merge pull request #42 from mindvalley/chore/merge-upstream-2024061901

chore/merge upstream 2024061901

onimsha authored Jun 19, 2024
2 parents 79d3bce + f93823d commit 5757a58
Showing 88 changed files with 3,761 additions and 1,026 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@
 .idea
 /deployment/data/nginx/app.conf
 .vscode/launch.json
+*.sw?
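
The new `*.sw?` entry ignores Vim swap files. A quick sanity check of the glob, using Python's fnmatch as a rough stand-in for gitignore matching:

```python
# fnmatch approximates gitignore globbing closely enough for this pattern.
from fnmatch import fnmatch

for name in (".main.py.swp", "notes.swo", "data.sw"):
    print(name, fnmatch(name, "*.sw?"))
# .main.py.swp True / notes.swo True / data.sw False (? must match one character)
```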
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
@@ -72,6 +72,10 @@ For convenience here's a command for it:
 python -m venv .venv
 source .venv/bin/activate
 ```
+
+--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
+directory
+
 _For Windows, activate the virtual environment using Command Prompt:_
 ```bash
 .venv\Scripts\activate
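One hedged way to double-check the warning above from inside an activated environment; the checkout path below is an assumption, adjust it to your local clone:

```python
# Sanity-check sketch: fail if the active venv lives inside the danswer checkout.
import sys
from pathlib import Path

venv_root = Path(sys.prefix).resolve()
repo_root = Path("~/danswer").expanduser().resolve()  # assumed clone location
if venv_root.is_relative_to(repo_root):  # Python 3.9+
    raise SystemExit("Move the virtualenv outside the danswer directory")
```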
@@ -0,0 +1,61 @@
"""Add support for custom tools
Revision ID: 48d14957fe80
Revises: b85f02ec1308
Create Date: 2024-06-09 14:58:19.946509
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "48d14957fe80"
down_revision = "b85f02ec1308"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"tool",
sa.Column(
"openapi_schema",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"tool",
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
)
op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])

op.create_table(
"tool_call",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("tool_id", sa.Integer(), nullable=False),
sa.Column("tool_name", sa.String(), nullable=False),
sa.Column(
"tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
),
)


def downgrade() -> None:
op.drop_table("tool_call")

op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
op.drop_column("tool", "user_id")
op.drop_column("tool", "openapi_schema")
22 changes: 22 additions & 0 deletions backend/alembic/versions/e209dc5a8156_added_prune_frequency.py
@@ -0,0 +1,22 @@
"""added-prune-frequency
Revision ID: e209dc5a8156
Revises: 48d14957fe80
Create Date: 2024-06-16 16:02:35.273231
"""
from alembic import op
import sqlalchemy as sa

revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None # type: ignore
depends_on = None # type: ignore


def upgrade() -> None:
op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))


def downgrade() -> None:
op.drop_column("connector", "prune_freq")
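
The two migrations chain b85f02ec1308 -> 48d14957fe80 -> e209dc5a8156. As a sketch, they can also be applied through Alembic's command API rather than the usual `alembic upgrade head`; the config path is an assumption about running from the backend/ directory:

```python
# Hedged sketch: apply both new revisions programmatically.
# Assumes the working directory is backend/, next to alembic.ini.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "e209dc5a8156")  # runs 48d14957fe80 first, then this one
```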
21 changes: 21 additions & 0 deletions backend/danswer/auth/invited_users.py
@@ -0,0 +1,21 @@
from typing import cast

from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro

USER_STORE_KEY = "INVITED_USERS"


def get_invited_users() -> list[str]:
    try:
        store = get_dynamic_config_store()
        return cast(list, store.load(USER_STORE_KEY))
    except ConfigNotFoundError:
        return list()


def write_invited_users(emails: list[str]) -> int:
    store = get_dynamic_config_store()
    store.store(USER_STORE_KEY, cast(JSON_ro, emails))
    return len(emails)
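
These helpers round-trip the invite list through the dynamic config store under a single key. A minimal usage sketch (addresses are placeholders):

```python
# Round-trip sketch for the helpers above; emails are illustrative only.
from danswer.auth.invited_users import get_invited_users, write_invited_users

count = write_invited_users(["alice@example.com", "bob@example.com"])
assert count == 2
assert "alice@example.com" in get_invited_users()
# get_invited_users() falls back to [] when the key has never been written
```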
6 changes: 6 additions & 0 deletions backend/danswer/auth/schemas.py
@@ -9,6 +9,12 @@ class UserRole(str, Enum):
     ADMIN = "admin"
 
 
+class UserStatus(str, Enum):
+    LIVE = "live"
+    INVITED = "invited"
+    DEACTIVATED = "deactivated"
+
 
 class UserRead(schemas.BaseUser[uuid.UUID]):
     role: UserRole
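Since UserStatus subclasses str, its members behave like their plain-string values, which keeps serialization trivial. A small sketch:

```python
# UserStatus members compare and serialize as plain strings (str subclass).
from danswer.auth.schemas import UserStatus

status = UserStatus("invited")       # construct from the raw value
assert status is UserStatus.INVITED
assert status == "invited"           # str subclass: equal to its value
```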
19 changes: 2 additions & 17 deletions backend/danswer/auth/users.py
@@ -1,4 +1,3 @@
-import os
 import smtplib
 import uuid
 from collections.abc import AsyncGenerator
@@ -27,6 +26,7 @@
 from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
 from sqlalchemy.orm import Session
 
+from danswer.auth.invited_users import get_invited_users
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRole
 from danswer.configs.app_configs import AUTH_TYPE
@@ -59,9 +59,6 @@
 
 logger = setup_logger()
 
-USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
-_user_whitelist: list[str] | None = None
-
 
 def verify_auth_setting() -> None:
     if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
@@ -92,20 +89,8 @@ def user_needs_to_be_verified() -> bool:
     return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
 
 
-def get_user_whitelist() -> list[str]:
-    global _user_whitelist
-    if _user_whitelist is None:
-        if os.path.exists(USER_WHITELIST_FILE):
-            with open(USER_WHITELIST_FILE, "r") as file:
-                _user_whitelist = [line.strip() for line in file]
-        else:
-            _user_whitelist = []
-
-    return _user_whitelist
-
-
 def verify_email_in_whitelist(email: str) -> None:
-    whitelist = get_user_whitelist()
+    whitelist = get_invited_users()
     if (whitelist and email not in whitelist) or not email:
         raise PermissionError("User not on allowed user whitelist")
 
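With the file-based whitelist gone, verify_email_in_whitelist now consults the invite list written above; note the guard is skipped entirely when the invite list is empty. A behavior sketch, assuming the two invited addresses from the earlier example are stored:

```python
# Behavior sketch; addresses carried over from the earlier example.
from danswer.auth.users import verify_email_in_whitelist

verify_email_in_whitelist("alice@example.com")  # invited -> returns None
try:
    verify_email_in_whitelist("mallory@example.com")  # not invited
except PermissionError as err:
    print(err)  # User not on allowed user whitelist
```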
133 changes: 116 additions & 17 deletions backend/danswer/background/celery/celery.py
@@ -4,13 +4,22 @@
 from celery import Celery  # type: ignore
 from sqlalchemy.orm import Session
 
+from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
+from danswer.background.celery.celery_utils import should_prune_cc_pair
+from danswer.background.celery.celery_utils import should_sync_doc_set
 from danswer.background.connector_deletion import delete_connector_credential_pair
+from danswer.background.connector_deletion import delete_connector_credential_pair_batch
 from danswer.background.task_utils import build_celery_task_wrapper
 from danswer.background.task_utils import name_cc_cleanup_task
+from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
 from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.connectors.factory import instantiate_connector
+from danswer.connectors.models import InputType
 from danswer.db.connector_credential_pair import get_connector_credential_pair
+from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
+from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.document import prepare_to_modify_documents
 from danswer.db.document_set import delete_document_set
 from danswer.db.document_set import fetch_document_sets
@@ -22,8 +31,6 @@
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.engine import SYNC_DB_API
 from danswer.db.models import DocumentSet
-from danswer.db.tasks import check_live_task_not_timed_out
-from danswer.db.tasks import get_latest_task
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
@@ -90,6 +97,74 @@ def cleanup_connector_credential_pair_task(
             raise e
 
 
+@build_celery_task_wrapper(name_cc_prune_task)
+@celery_app.task(soft_time_limit=JOB_TIMEOUT)
+def prune_documents_task(connector_id: int, credential_id: int) -> None:
+    """Connector pruning task. For a cc pair, this task pulls all document IDs
+    from the source, compares them against the locally stored documents, and
+    deletes any local document IDs missing from the freshly pulled ID list."""
+    with Session(get_sqlalchemy_engine()) as db_session:
+        try:
+            cc_pair = get_connector_credential_pair(
+                db_session=db_session,
+                connector_id=connector_id,
+                credential_id=credential_id,
+            )
+
+            if not cc_pair:
+                logger.warning(f"ccpair not found for {connector_id} {credential_id}")
+                return
+
+            runnable_connector = instantiate_connector(
+                cc_pair.connector.source,
+                InputType.PRUNE,
+                cc_pair.connector.connector_specific_config,
+                cc_pair.credential,
+                db_session,
+            )
+
+            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
+                runnable_connector
+            )
+
+            all_indexed_document_ids = {
+                doc.id
+                for doc in get_documents_for_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                )
+            }
+
+            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
+
+            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
+            document_index = get_default_document_index(
+                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
+            )
+
+            if len(doc_ids_to_remove) == 0:
+                logger.info(
+                    f"No docs to prune from {cc_pair.connector.source} connector"
+                )
+                return
+
+            logger.info(
+                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
+            )
+            delete_connector_credential_pair_batch(
+                document_ids=doc_ids_to_remove,
+                connector_id=connector_id,
+                credential_id=credential_id,
+                document_index=document_index,
+            )
+        except Exception as e:
+            logger.exception(
+                f"Failed to run pruning for connector id {connector_id} due to {e}"
+            )
+            raise e
+
+
 @build_celery_task_wrapper(name_document_set_sync_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_document_set_task(document_set_id: int) -> None:
@@ -188,32 +263,48 @@ def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
     soft_time_limit=JOB_TIMEOUT,
 )
 def check_for_document_sets_sync_task() -> None:
-    """Runs periodically to check if any document sets are out of sync
-    Creates a task to sync the set if needed"""
+    """Runs periodically to check if any sync tasks should be run and adds them
+    to the queue"""
     with Session(get_sqlalchemy_engine()) as db_session:
         # check if any document sets are not synced
         document_set_info = fetch_document_sets(
             user_id=None, db_session=db_session, include_outdated=True
         )
         for document_set, _ in document_set_info:
-            if not document_set.is_up_to_date:
-                task_name = name_document_set_sync_task(document_set.id)
-                latest_sync = get_latest_task(task_name, db_session)
-
-                if latest_sync and check_live_task_not_timed_out(
-                    latest_sync, db_session
-                ):
-                    logger.info(
-                        f"Document set '{document_set.id}' is already syncing. Skipping."
-                    )
-                    continue
-
-                logger.info(f"Document set {document_set.id} syncing now!")
+            if should_sync_doc_set(document_set, db_session):
+                logger.info(f"Syncing the {document_set.name} document set")
                 sync_document_set_task.apply_async(
                     kwargs=dict(document_set_id=document_set.id),
                 )
 
 
+@celery_app.task(
+    name="check_for_prune_task",
+    soft_time_limit=JOB_TIMEOUT,
+)
+def check_for_prune_task() -> None:
+    """Runs periodically to check if any prune tasks should be run and adds them
+    to the queue"""
+
+    with Session(get_sqlalchemy_engine()) as db_session:
+        all_cc_pairs = get_connector_credential_pairs(db_session)
+
+        for cc_pair in all_cc_pairs:
+            if should_prune_cc_pair(
+                connector=cc_pair.connector,
+                credential=cc_pair.credential,
+                db_session=db_session,
+            ):
+                logger.info(f"Pruning the {cc_pair.connector.name} connector")
+
+                prune_documents_task.apply_async(
+                    kwargs=dict(
+                        connector_id=cc_pair.connector.id,
+                        credential_id=cc_pair.credential.id,
+                    )
+                )
+
+
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
@@ -223,3 +314,11 @@ def check_for_document_sets_sync_task() -> None:
         "schedule": timedelta(seconds=5),
     },
 }
+celery_app.conf.beat_schedule.update(
+    {
+        "check-for-prune": {
+            "task": "check_for_prune_task",
+            "schedule": timedelta(seconds=5),
+        },
+    }
+)
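
Net effect: Celery Beat now enqueues check_for_prune_task every five seconds alongside the document-set sync check, and each eligible cc pair is pruned by a plain set difference. A toy illustration of that rule (IDs are invented):

```python
# Toy illustration of the set difference at the heart of prune_documents_task.
source_doc_ids = {"doc-1", "doc-2"}            # IDs still present at the source
indexed_doc_ids = {"doc-1", "doc-2", "doc-3"}  # IDs stored in the local index
doc_ids_to_remove = list(indexed_doc_ids - source_doc_ids)
print(doc_ids_to_remove)  # ['doc-3'] -- only the stale document is deleted
```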