Skip to content

Commit

Permalink
v6.11
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Dec 8, 2024
1 parent 4baf306 commit 5b6ace9
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 121 deletions.
14 changes: 11 additions & 3 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,16 @@
'type': 'vector',
'precision': 'float16'
},
{
'name': 'sentence-t5-xxl',
'dimensions': 768,
'max_sequence': 256,
'size_mb': 9230,
'repo_id': 'sentence-transformers/sentence-t5-xxl',
'cache_dir': 'sentence-transformers--sentence-t5-xxl',
'type': 'vector',
'precision': 'float16'
},
],
'thenlper': [
{
Expand Down Expand Up @@ -929,7 +939,7 @@
"</table>"
"</body>"
"</html>"
),
),
"VECTOR_MODEL_SELECT": "Choose a vector model to download.",
"VECTOR_MODEL_SIZE": "Size on disk.",
"VISION_MODEL": "Select vision model for image processing. Test before bulk processing.",
Expand All @@ -938,8 +948,6 @@
"WHISPER_MODEL_SELECT": "Distil models use ~ 70% VRAM of their non-Distil equivalents with little quality loss."
}



scrape_documentation = {
"Accelerate 0.34.2": {
"URL": "https://huggingface.co/docs/accelerate/v0.34.2/en/",
Expand Down
33 changes: 29 additions & 4 deletions src/database_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Optional
import threading
import re

import sqlite3
import torch
Expand All @@ -23,7 +24,32 @@
def create_vector_db_in_process(database_name):
create_vector_db = CreateVectorDB(database_name=database_name)
create_vector_db.run()


def process_chunks_only_query(database_name, query, result_queue):
try:
query_db = QueryVectorDB(database_name)
contexts, metadata_list = query_db.search(query)

formatted_contexts = []
for index, (context, metadata) in enumerate(zip(contexts, metadata_list), start=1):
file_name = metadata.get('file_name', 'Unknown')
cleaned_context = re.sub(r'\n[ \t]+\n', '\n\n', context)
cleaned_context = re.sub(r'\n\s*\n\s*\n*', '\n\n', cleaned_context.strip())
formatted_context = (
f"{'-'*80}\n"
f"CONTEXT {index} | {file_name}\n"
f"{'-'*80}\n"
f"{cleaned_context}\n"
)
formatted_contexts.append(formatted_context)

result_queue.put("\n".join(formatted_contexts))
except Exception as e:
result_queue.put(f"Error querying database: {str(e)}")
finally:
if 'query_db' in locals():
query_db.cleanup()

class CreateVectorDB:
def __init__(self, database_name):
self.ROOT_DIRECTORY = Path(__file__).resolve().parent
Expand Down Expand Up @@ -57,6 +83,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
encode_kwargs['batch_size'] = 2
else:
batch_size_mapping = {
't5-xxl': 2,
't5-xl': 2,
't5-large': 4,
'instructor-xl': 2,
Expand Down Expand Up @@ -354,7 +381,7 @@ def get_instance(cls, selected_database):
with cls._instance_lock:
if cls._instance is not None:
if cls._instance.selected_database != selected_database:
logging.info(f"Database changed from {cls._instance.selected_database} to {selected_database}")
print(f"Database changed from {cls._instance.selected_database} to {selected_database}")
cls._instance.cleanup()
cls._instance = None
else:
Expand All @@ -366,7 +393,6 @@ def get_instance(cls, selected_database):
return cls._instance

def load_configuration(self):
"""Load configuration from config.yaml file"""
config_path = Path(__file__).resolve().parent / 'config.yaml'
try:
with open(config_path, 'r', encoding='utf-8') as file:
Expand Down Expand Up @@ -439,7 +465,6 @@ def search(self, query, k: Optional[int] = None, score_threshold: Optional[float
logging.info(f"Initializing database connection for {self.selected_database}")
self.db = self.initialize_database()

# The rest of your existing search method remains unchanged
self.config = self.load_configuration()
document_types = self.config['database'].get('document_types', '')
search_filter = {'document_type': document_types} if document_types else {}
Expand Down
4 changes: 1 addition & 3 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@ def set_cuda_paths():
cuda_path = nvidia_base_path / 'cuda_runtime' / 'bin'
cublas_path = nvidia_base_path / 'cublas' / 'bin'
cudnn_path = nvidia_base_path / 'cudnn' / 'bin'
# nvcc_path = nvidia_base_path / 'cuda_nvcc' / 'bin'
nvrtc_path = nvidia_base_path / 'cuda_nvrtc' / 'bin'

paths_to_add = [
str(cuda_path), # CUDA runtime
str(cublas_path), # cuBLAS
str(cudnn_path), # cuDNN
# str(nvcc_path), # NVIDIA CUDA compiler
str(nvrtc_path), # NVIDIA runtime compiler
]

env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_1', 'PATH']
env_vars = ['CUDA_PATH', 'PATH']

for env_var in env_vars:
current_value = os.environ.get(env_var, '')
Expand Down
43 changes: 23 additions & 20 deletions src/gui_tabs_database_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,45 @@
from module_voice_recorder import VoiceRecorder
from utilities import check_preconditions_for_submit_question, my_cprint
from constants import TOOLTIPS
from database_interactions import QueryVectorDB
from database_interactions import QueryVectorDB, process_chunks_only_query

current_dir = Path(__file__).resolve().parent
input_text_file = str(current_dir / 'chat_history.txt')


class DatabaseQueryThread(QThread):
class ChunksOnlyThread(QThread):
chunks_ready = Signal(str)

def __init__(self, query, database_name):
super().__init__()
self.query = query
self.database_name = database_name
self.process = None

def run(self):
try:
query_db = QueryVectorDB(self.database_name)
contexts, metadata_list = query_db.search(self.query)
formatted_chunks = self.format_chunks(contexts, metadata_list)
self.chunks_ready.emit(formatted_chunks)
result_queue = multiprocessing.Queue()

self.process = multiprocessing.Process(
target=process_chunks_only_query,
args=(self.database_name, self.query, result_queue)
)
self.process.start()

result = result_queue.get()
self.chunks_ready.emit(result)

self.process.join()
self.process = None

except Exception as e:
logging.exception(f"Error in database query thread: {e}")
logging.exception(f"Error in chunks only thread: {e}")
self.chunks_ready.emit(f"Error querying database: {str(e)}")

@staticmethod
def format_chunks(contexts, metadata_list):
formatted_contexts = []
for index, (context, metadata) in enumerate(zip(contexts, metadata_list), start=1):
file_name = metadata.get('file_name', 'Unknown')
formatted_context = (
f"---------- Context {index} | From File: {file_name} ----------\n"
f"{context}\n"
)
formatted_contexts.append(formatted_context)
return "\n".join(formatted_contexts)
def stop(self):
if self.process and self.process.is_alive():
self.process.terminate()
self.process.join()

def run_tts_in_process(config_path, input_text_file):
from module_tts import run_tts # Import here to avoid potential circular imports
Expand Down Expand Up @@ -250,9 +254,8 @@ def on_submit_button_clicked(self):
chunks_only = self.chunks_only_checkbox.isChecked()

selected_database = self.database_pulldown.currentText()

if chunks_only: # only get chunks
self.database_query_thread = DatabaseQueryThread(user_question, selected_database)
self.database_query_thread = ChunksOnlyThread(user_question, selected_database)
self.database_query_thread.chunks_ready.connect(self.display_chunks)
self.database_query_thread.finished.connect(self.on_database_query_finished)
self.database_query_thread.start()
Expand Down
7 changes: 6 additions & 1 deletion src/gui_tabs_manage_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self):

self.database_info_layout = QHBoxLayout()
self.database_info_label = QLabel("No database selected.")
self.database_info_label.setTextFormat(Qt.RichText)
self.database_info_layout.addWidget(self.database_info_label)
self.layout.addLayout(self.database_info_layout)

Expand Down Expand Up @@ -122,7 +123,11 @@ def update_table_view_and_info_label(self, index):
model_name = Path(model_path).name
chunk_size = db_config.get('chunk_size', '')
chunk_overlap = db_config.get('chunk_overlap', '')
info_text = f"DB name: \"{selected_database}\" | Created with \"{model_name}\" | Chunk size/overlap = {chunk_size} / {chunk_overlap}."
info_text = (f'<span style="color: #4CAF50;"><b>Name:</b></span> "{selected_database}" '
f'<span style="color: #888;">|</span> '
f'<span style="color: #2196F3;"><b>Model:</b></span> "{model_name}" '
f'<span style="color: #888;">|</span> '
f'<span style="color: #FF9800;"><b>Chunk size/overlap:</b></span> {chunk_size} / {chunk_overlap}')
self.database_info_label.setText(info_text)
else:
self.database_info_label.setText("Configuration missing.")
Expand Down
17 changes: 11 additions & 6 deletions src/gui_tabs_tools_scrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ def init_ui(self):

main_layout.addLayout(hbox)

self.status_label = QLabel("Pages scraped: 0")
self.status_label = QLabel()
self.status_label.setTextFormat(Qt.RichText)
self.status_label.setOpenExternalLinks(False)
self.status_label.setToolTip("Click the links to open the folder containing scraped data.")
self.status_label.linkActivated.connect(self.open_folder)
self.status_label.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)
self.status_label.setText('<span style="color: #4CAF50;"><b>Pages scraped:</b></span> 0')
main_layout.addWidget(self.status_label)

main_layout.addStretch()
Expand Down Expand Up @@ -115,16 +116,20 @@ def start_scraping(self):
self.thread.start()

def update_status(self, status):
self.status_label.setText(f'Pages scraped: {status} <a href="open_folder">Open Folder</a>')
self.status_label.setText(
f'<span style="color: #4CAF50;"><b>Pages scraped:</b></span> {status} '
f'<span style="color: #2196F3;"><a href="open_folder" style="color: #2196F3;">Open Folder</a></span>'
)

def scraping_finished(self):
self.scrape_button.setEnabled(True)
selected_doc = self.doc_combo.currentText()
final_count = len([f for f in os.listdir(self.current_folder) if f.endswith('.html')])
self.status_label.setText(f'Scraping {selected_doc} completed. Pages scraped: {final_count} <a href="open_folder">Open Folder</a>')
self.populate_combo_box()
if hasattr(self.worker, 'cleanup'):
self.worker.cleanup()
self.status_label.setText(
f'<span style="color: #FF9800;"><b>Scraping {selected_doc} completed.</b></span> '
f'<span style="color: #4CAF50;"><b>Pages scraped:</b></span> {final_count} '
f'<span style="color: #2196F3;"><a href="open_folder" style="color: #2196F3;">Open Folder</a></span>'
)

def open_folder(self, link):
if link == "open_folder":
Expand Down
5 changes: 2 additions & 3 deletions src/module_ask_jeeves.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import subprocess
import sys
import platform
import signal
import os
Expand All @@ -12,8 +11,8 @@
import sseclient
from huggingface_hub import snapshot_download
from PySide6.QtWidgets import (
QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
QTextEdit, QLineEdit, QCommandLinkButton, QMessageBox,
QMainWindow, QWidget, QVBoxLayout,
QTextEdit, QLineEdit, QMessageBox,
QLabel, QApplication, QProgressDialog
)
from PySide6.QtCore import QThread, Signal, QObject, Qt
Expand Down
5 changes: 3 additions & 2 deletions src/module_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,8 @@ def __init__(self):
def initialize_model_and_processor(self):
repository_id = "suno/bark" if self.config['size'] == 'normal' else f"suno/bark-{self.config['size']}"

# processor
self.processor = AutoProcessor.from_pretrained(repository_id, cache_dir=CACHE_DIR)

# model
self.model = BarkModel.from_pretrained(
repository_id,
torch_dtype=torch.float16,
Expand All @@ -133,6 +131,9 @@ def process_text_to_audio(self, sentences):
**inputs,
use_cache=True,
do_sample=True,
# temperature=0.2,
# top_k=50,
# top_p=0.95,
pad_token_id=0,
)

Expand Down
Loading

0 comments on commit 5b6ace9

Please sign in to comment.