-
Notifications
You must be signed in to change notification settings - Fork 0
/
tasks.py
82 lines (63 loc) · 2.42 KB
/
tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from invoke import task
from quinn_gpt.scrapers import DocsScraper
from quinn_gpt.db import QuinnDB
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import UnstructuredHTMLLoader
from tqdm import tqdm
import os
# Unreal Engine docs version being scraped/indexed.
VERSION = '5.1'
# Plain string: the original was an f-string with no placeholders (ruff F541).
PERSIST_DIR = './chromadb/quinn_gpt'
# Module-level singletons shared by the tasks below.
qdb = QuinnDB('quinn_gpt')
scraper = DocsScraper(VERSION, qdb)
@task
def run(c, url):
    """Scrape a single documentation URL with the shared DocsScraper."""
    scraper.scrape_url(url, VERSION)
@task
def run_all(c):
    """Crawl the whole Unreal Engine docs site for the configured VERSION."""
    root = 'https://docs.unrealengine.com/' + VERSION + '/en-US/'
    scraper.crawl_site(root)
@task
def cache_to_chroma(c, chunk_size=400, reset=True):
    """Split every cached HTML page in .cache into chunks and add them to the
    persisted Chroma index.

    chunk_size: max characters per chunk; cast with int() because invoke
        passes CLI overrides as strings.
    reset: currently unused -- TODO(review): either wipe PERSIST_DIR when
        True or remove the parameter.
    """
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    chroma = Chroma(persist_directory=PERSIST_DIR, embedding_function=embeddings)
    chroma.persist()
    # Loop-invariant: the splitter is stateless across files, build it once.
    # (The original also constructed a second, unused SentenceTransformer
    # model on every iteration -- removed.)
    splitter = RecursiveCharacterTextSplitter(chunk_size=int(chunk_size), chunk_overlap=0)
    for filename in tqdm(os.listdir('.cache')):
        documents = UnstructuredHTMLLoader(".cache/" + filename).load()
        chroma.add_documents(splitter.split_documents(documents))
@task
def query(c, query, k=5):
    """Similarity-search the persisted Chroma index and print each hit's text.

    k: number of results to return; cast with int() because invoke passes
       CLI overrides as strings (consistent with cache_to_chroma's
       int(chunk_size) handling).
    """
    chroma = Chroma(persist_directory=PERSIST_DIR, embedding_function=SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2"))
    results = chroma.similarity_search(query, k=int(k))
    for result in results:
        print(result.page_content)
@task
def estimate_cost(c):
    """Estimate the embedding cost of every cached page in .cache.

    Counts whitespace-separated words per file, approximates tokens as
    words * 1.3 (rough tokens-per-word heuristic for English), and prices
    them at $0.0001 per 1K tokens.
    """
    total_cost = 0
    total_words = 0
    total_tokens = 0
    for filename in os.listdir('.cache'):
        # BUG FIX: the original opened the literal path '.cache/(unknown)'
        # instead of interpolating the loop variable.
        with open(f'.cache/{filename}', 'r') as f:
            text = f.read()
        words = len(text.split())
        tokens = words * 1.3  # ~1.3 tokens per word heuristic
        total_tokens += tokens
        total_words += words
        cost = tokens / 1000 * .0001  # $0.0001 per 1K tokens
        total_cost += cost
    print(f'{total_words} words, ${total_cost}')
@task
def test(c):
    """Run the pytest suite with coverage reporting for quinn_gpt."""
    c.run('pytest ./tests --cov=quinn_gpt --cov-report=term-missing')
@task
def remove_pound(c):
    """Purge fragment ('#') URLs from the QuinnDB store."""
    qdb.remove_hashed_urls()