-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathqabot.py
61 lines (49 loc) · 1.82 KB
/
qabot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from langchain_community.llms import CTransformers
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores import FAISS
# Configuration: paths to the local quantized LLM and the persisted FAISS index.
model_file = "models/vinallama-7b-chat_q5_0.gguf"  # GGUF llama-family chat model loaded by CTransformers
vector_db_path = "vectorstores/db_faiss"  # directory written by a prior indexing step; read in read_vectors_db()
# Load the local quantized LLM.
def load_llm(model_file):
    """Instantiate a CTransformers LLM for the given GGUF model file.

    Args:
        model_file: Path to a llama-family GGUF model on disk.

    Returns:
        CTransformers: LLM configured for near-deterministic generation
        (temperature 0.01) with up to 1024 new tokens per response.
    """
    settings = {
        "model": model_file,
        "model_type": "llama",
        "max_new_tokens": 1024,
        "temperature": 0.01,
    }
    return CTransformers(**settings)
# Build the prompt template used by the QA chain.
def creat_prompt(template):
    """Create a PromptTemplate expecting `context` and `question` variables.

    Args:
        template: Raw template string containing `{context}` and
            `{question}` placeholders.

    Returns:
        PromptTemplate: template ready to be passed to the QA chain.
    """
    return PromptTemplate(template=template, input_variables=["context", "question"])


# Correctly-spelled alias; the misspelled `creat_prompt` above is kept as the
# definition so existing callers keep working.
create_prompt = creat_prompt
# Assemble the retrieval-augmented QA chain.
def create_qa_chain(prompt, llm, db):
    """Wire the LLM, prompt and vector-store retriever into a RetrievalQA chain.

    Args:
        prompt: PromptTemplate with `context` and `question` variables.
        llm: Language model used to answer.
        db: Vector store exposing `as_retriever`.

    Returns:
        RetrievalQA: a "stuff" chain that retrieves the top-3 documents
        (capped at 1024 tokens) and does not return source documents.
    """
    retriever = db.as_retriever(search_kwargs={"k": 3}, max_tokens_limit=1024)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt},
    )
    return qa_chain
# Load the persisted FAISS vector store from disk.
def read_vectors_db():
    """Load the FAISS index at `vector_db_path` with GPT4All embeddings.

    Returns:
        FAISS: vector store ready to be used as a retriever.
    """
    # Embedding model must match the one used when the index was built.
    # NOTE(review): GPT4AllEmbeddings' documented keyword is `model_name`;
    # `model_file` may be ignored depending on the installed
    # langchain_community version — confirm.
    embedding_model = GPT4AllEmbeddings(model_file="models/all-MiniLM-L6-v2-f16.gguf")
    # NOTE(review): recent langchain versions require
    # `allow_dangerous_deserialization=True` here (FAISS indexes are
    # pickle-backed) — verify against the pinned version before upgrading.
    db = FAISS.load_local(vector_db_path, embedding_model)
    return db
# --- Script entry: wire everything together and run one sample query. ---
db = read_vectors_db()
llm = load_llm(model_file)
# ChatML-style prompt (system/user/assistant markers) for the vinallama chat
# model; the system text (Vietnamese) instructs it to answer only from the
# provided context and to say "don't know" otherwise.
template = """<|im_start|>system\nSử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời\n
{context}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant"""
prompt = creat_prompt(template)
llm_chain =create_qa_chain(prompt, llm, db)
# Run the chain on one sample question (Vietnamese: "What did SHB do on 18/12?").
question = "Ngày 18/12, SHB đã làm gì?"
response = llm_chain.invoke({"query": question})
print(response)