Skip to content

Commit

Permalink
Feat/0.2.1 (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 authored Dec 20, 2023
2 parents b20f259 + fb29df8 commit 2b396bb
Show file tree
Hide file tree
Showing 22 changed files with 328 additions and 105 deletions.
157 changes: 116 additions & 41 deletions src/backend/bisheng/api/v1/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import time
from typing import List, Optional
from uuid import uuid4
from xml.dom.minidom import Document

import requests
from bisheng.api.utils import access_check
from bisheng.api.v1.schemas import ChunkInput, UploadFileResponse
from bisheng.api.v1.schemas import UploadFileResponse
from bisheng.cache.utils import file_download, save_uploaded_file
from bisheng.database.base import get_session
from bisheng.database.models.knowledge import Knowledge, KnowledgeCreate, KnowledgeRead
Expand Down Expand Up @@ -144,6 +145,7 @@ async def process_knowledge(*,
background_tasks.add_task(
addEmbedding,
collection_name=collection_name,
index_name=knowledge.index_name or knowledge.collection_name,
knowledge_id=knowledge_id,
model=knowledge.model,
chunk_size=chunk_size,
Expand Down Expand Up @@ -183,6 +185,7 @@ def create_knowledge(*,
else:
# 默认collectionName
db_knowldge.collection_name = f'col_{int(time.time())}_{str(uuid4())[:8]}'
db_knowldge.index_name = f'col_{int(time.time())}_{str(uuid4())[:8]}'
db_knowldge.user_id = payload.get('user_id')
session.add(db_knowldge)
session.commit()
Expand Down Expand Up @@ -302,8 +305,7 @@ def delete_knowledge(*,
else:
pk = vectore_client.col.query(expr=f'knowledge_id=="{knowledge.id}"',
output_fields=['pk'])
vectore_client.col.delete(f"pk in {[p['pk'] for p in pk]}",
partition_name='knowledge_id')
vectore_client.col.delete(f"pk in {[p['pk'] for p in pk]}")
# 处理 es
# todo

Expand Down Expand Up @@ -374,18 +376,19 @@ def decide_vectorstores(collection_name: str, vector_store: str,
vector_config['ssl_verify'] = eval(vector_config['ssl_verify'])
else:
param = {'collection_name': collection_name, 'embedding': embedding}
vector_config.pop('partition_suffix', '')
param.update(vector_config)
class_obj = import_vectorstore(vector_store)
return instantiate_vectorstore(class_object=class_obj, params=param)


def addEmbedding(collection_name, knowledge_id: int, model: str, chunk_size: int, separator: str,
chunk_overlap: int, file_paths: List[str], knowledge_files: List[KnowledgeFile],
callback: str):
def addEmbedding(collection_name, index_name, knowledge_id: int, model: str, chunk_size: int,
separator: str, chunk_overlap: int, file_paths: List[str],
knowledge_files: List[KnowledgeFile], callback: str):
try:
embeddings = decide_embeddings(model)
vectore_client = decide_vectorstores(collection_name, 'Milvus', embeddings)
es_client = decide_vectorstores(collection_name, 'ElasticKeywordsSearch', embeddings)
es_client = decide_vectorstores(index_name, 'ElasticKeywordsSearch', embeddings)
except Exception as e:
logger.exception(e)

Expand All @@ -408,6 +411,8 @@ def addEmbedding(collection_name, knowledge_id: int, model: str, chunk_size: int
texts, metadatas = _read_chunk_text(path, knowledge_file.file_name, chunk_size,
chunk_overlap, separator)

if len(texts) == 0:
raise ValueError('文件解析为空')
# 溯源必须依赖minio, 后期替换更通用的oss
minio_client.upload_minio(str(db_file.id), path)

Expand Down Expand Up @@ -459,7 +464,12 @@ def _read_chunk_text(input_file, file_name, size, chunk_overlap, separator):
documents = loader.load()
texts = text_splitter.split_documents(documents)
raw_texts = [t.page_content for t in texts]
metadatas = [t.metadata.update({'bbox': '', 'source': file_name}) for t in texts]
metadatas = [{
'bbox': json.dumps({'chunk_bboxes': t.metadata.get('chunk_bboxes', '')}),
'page': t.metadata.get('page'),
'source': file_name,
'extra': ''
} for t in texts]
metadatas = [t.metadata for t in texts]
else:
# 如果文件不是pdf 需要内部转pdf
Expand Down Expand Up @@ -492,14 +502,68 @@ def _read_chunk_text(input_file, file_name, size, chunk_overlap, separator):
'bbox': json.dumps({'chunk_bboxes': t.metadata.get('chunk_bboxes', '')}),
'page': t.metadata.get('chunk_bboxes')[0].get('page'),
'source': t.metadata.get('source', ''),
'extra': {},
'extra': '',
} for t in texts]
return (raw_texts, metadatas)


def file_knowledge(
    db_knowledge: Knowledge,
    file_path: str,
    file_name: str,
    metadata: str,
    session: Session = Depends(get_session),
):
    """Chunk a local file, embed the chunks into Milvus/Elasticsearch and
    record the file in MySQL.

    Args:
        db_knowledge: knowledge-base row the file belongs to; supplies the
            embedding model name and the Milvus/ES collection name.
        file_path: local path of the file to parse and split.
        file_name: display name stored as each chunk's ``source``.
        metadata: JSON string of caller-supplied extra metadata; its ``url``
            key is persisted as the file's object name.
        session: database session (FastAPI dependency).

    Raises:
        Exception: re-raised if the embedding model or a vector-store client
            cannot be initialized.
    """
    try:
        embeddings = decide_embeddings(db_knowledge.model)
        vectore_client = decide_vectorstores(db_knowledge.collection_name, 'Milvus', embeddings)
        es_client = decide_vectorstores(db_knowledge.collection_name, 'ElasticKeywordsSearch',
                                        embeddings)
    except Exception as e:
        logger.exception(e)
        # Bug fix: the original fell through after logging, so the code below
        # raised a NameError on vectore_client/es_client and masked the real
        # initialization failure. Re-raise to surface the actual cause.
        raise
    separator = ['\n\n', '\n', ' ', '']
    chunk_size = 500
    chunk_overlap = 50
    raw_texts, metadatas = _read_chunk_text(file_path, file_name, chunk_size, chunk_overlap,
                                            separator)
    logger.info(f'chunk_split file_name={file_name} size={len(raw_texts)}')
    metadata_extra = json.loads(metadata)
    # Persist the MySQL record first so every chunk can reference db_file.id.
    # NOTE(review): status=1 appears to mean "processing" — confirm against
    # the KnowledgeFile model.
    db_file = KnowledgeFile(knowledge_id=db_knowledge.id,
                            file_name=file_name,
                            status=1,
                            object_name=metadata_extra.get('url'))
    session.add(db_file)
    session.flush()

    try:
        # Per-chunk metadata for the vector stores. Renamed from `metadata`
        # to stop shadowing the function parameter of the same name.
        chunk_metadatas = [{
            'file_id': db_file.id,
            'knowledge_id': f'{db_knowledge.id}',
            'page': chunk_meta.get('page'),
            'source': chunk_meta.get('source'),
            'bbox': chunk_meta.get('bbox'),
            'extra': json.dumps(metadata_extra)
        } for chunk_meta in metadatas]
        vectore_client.add_texts(texts=raw_texts, metadatas=chunk_metadatas)

        # Mirror the chunks into Elasticsearch for keyword search.
        if es_client:
            es_client.add_texts(texts=raw_texts, metadatas=chunk_metadatas)
        db_file.status = 2  # presumably "embedding finished" — confirm
        session.commit()

    except Exception as e:
        logger.error(e)
        # Mark the file failed (status=3) and keep a truncated error message
        # so the failure is visible to the caller/UI.
        setattr(db_file, 'status', 3)
        setattr(db_file, 'remark', str(e)[:500])
        session.add(db_file)
        session.commit()


def text_knowledge(
db_knowledge: Knowledge,
documents: ChunkInput,
documents: List[Document],
session: Session = Depends(get_session),
):
try:
Expand All @@ -510,37 +574,48 @@ def text_knowledge(
except Exception as e:
logger.exception(e)

documents_list = documents.documents
for document in documents_list:
metadata = document.metadata
# 存储 mysql
db_file = KnowledgeFile(knowledge_id=db_knowledge.id,
file_name=metadata.get('source'),
status=1,
object_name=metadata.get('url'))
try:
session.add(db_file)
session.flush()
separator = ['\n\n', '\n', ' ', '']
chunk_size = 500
chunk_overlap = 50

logger.info(
f'chunk_split file_name={db_file.file_name} size={len(document.page_content)}')
metadata = {
'file_id': db_file.id,
'knowledge_id': f'{db_knowledge.id}',
'bbox': '',
'extra': json.dumps(metadata)
}
vectore_client.add_texts(texts=[document.page_content], metadatas=[metadata])
text_splitter = CharacterTextSplitter(separator=separator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True)

# 存储es
if es_client:
es_client.add_texts(texts=[document.page_content], metadatas=[metadata])
db_file.status = 2
session.commit()
texts = text_splitter.split_documents(documents)

except Exception as e:
logger.error(e)
setattr(db_file, 'status', 3)
setattr(db_file, 'remark', str(e)[:500])
session.add(db_file)
session.commit()
logger.info(f'chunk_split knowledge_id={db_knowledge.id} size={len(texts)}')

# 存储 mysql
file_name = documents[0].metadata.get('source')
db_file = KnowledgeFile(knowledge_id=db_knowledge.id,
file_name=file_name,
status=1,
object_name=documents[0].metadata.get('url'))
session.add(db_file)
session.flush()

try:
metadata = [{
'file_id': db_file.id,
'knowledge_id': f'{db_knowledge.id}',
'page': doc.metadata.pop('page', ''),
'source': doc.metadata.get('source', ''),
'bbox': doc.metadata.get('bbox', ''),
'extra': json.dumps(doc.metadata)
} for doc in documents]
vectore_client.add_texts(texts=[t.page_content for t in texts], metadatas=metadata)

# 存储es
if es_client:
es_client.add_texts(texts=[t.page_content for t in texts], metadatas=metadata)
db_file.status = 2
session.commit()

except Exception as e:
logger.error(e)
setattr(db_file, 'status', 3)
setattr(db_file, 'remark', str(e)[:500])
session.add(db_file)
session.commit()
1 change: 1 addition & 0 deletions src/backend/bisheng/api/v1/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ChatMessage(BaseModel):
sender: str = None
receiver: dict = None
liked: int = 0
extra: str = '{}'


class ChatResponse(ChatMessage):
Expand Down
44 changes: 23 additions & 21 deletions src/backend/bisheng/api/v1/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,15 @@ async def login(*,


@router.get('/user/info', response_model=UserRead, status_code=201)
async def get_info(
session: Session = Depends(get_session),
Authorize: AuthJWT = Depends()):
async def get_info(session: Session = Depends(get_session), Authorize: AuthJWT = Depends()):
# check if user already exist
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
try:
user_id = payload.get('user_id')
user = session.get(User, user_id)
# 查询角色
db_user_role = session.exec(
select(UserRole).where(UserRole.user_id == user_id)).all()
db_user_role = session.exec(select(UserRole).where(UserRole.user_id == user_id)).all()
if next((user_role for user_role in db_user_role if user_role.role_id == 1), None):
# 是管理员,忽略其他的角色
role = 'admin'
Expand Down Expand Up @@ -265,7 +262,7 @@ async def user_addrole(*,
if 'admin' != json.loads(Authorize.get_jwt_subject()).get('role'):
raise HTTPException(status_code=500, detail='无设置权限')

db_role = session.exec(select(UserRole).where(UserRole.user_id == userRole.user_id,)).all()
db_role = session.exec(select(UserRole).where(UserRole.user_id == userRole.user_id, )).all()
role_ids = {role.role_id for role in db_role}
for role_id in userRole.role_id:
if role_id not in role_ids:
Expand Down Expand Up @@ -401,15 +398,17 @@ async def knowledge_list(*,
user_dict = {user.user_id: user.user_name for user in db_users}

return {
'msg': 'success',
'msg':
'success',
'data': [{
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
} for access in db_role_access],
'total': total_count
'total':
total_count
}


Expand All @@ -425,11 +424,12 @@ async def flow_list(*,
if 'admin' != json.loads(Authorize.get_jwt_subject()).get('role'):
raise HTTPException(status_code=500, detail='无查看权限')

statment = select(Flow, RoleAccess).join(RoleAccess,
and_(RoleAccess.role_id == role_id,
RoleAccess.type == AccessType.FLOW.value,
RoleAccess.third_id == Flow.id),
isouter=True)
statment = select(Flow.id, Flow.name, Flow.user_id, Flow.update_time,
RoleAccess).join(RoleAccess,
and_(RoleAccess.role_id == role_id,
RoleAccess.type == AccessType.FLOW.value,
RoleAccess.third_id == Flow.id),
isouter=True)
count_sql = select(func.count(Flow.id))

if name:
Expand All @@ -444,19 +444,21 @@ async def flow_list(*,
total_count = session.scalar(count_sql)

# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
user_ids = [access[2] for access in db_role_access]
db_users = session.query(User).filter(User.user_id.in_(user_ids)).all()
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'msg': 'success',
'msg':
'success',
'data': [{
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
'name': access[1],
'user_name': user_dict.get(access[2]),
'user_id': access[2],
'update_time': access[3],
'id': access[0]
} for access in db_role_access],
'total': total_count
'total':
total_count
}


Expand Down
Loading

0 comments on commit 2b396bb

Please sign in to comment.