# main.py
# Forked from openai/chatgpt-retrieval-plugin.
import os
import uvicorn
import pinecone
import openai
from fastapi import FastAPI, File, HTTPException, Depends, Body, UploadFile
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Optional
# Credentials and deployment settings come from the environment; a variable
# that is not set yields None.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_API_ENV = os.getenv("PINECONE_API_ENV")

openai.api_key = OPENAI_API_KEY

# HTTP bearer security scheme used as a dependency by validate_token.
bearer_scheme = HTTPBearer()

# Module-level bearer-token configuration is currently disabled:
# BEARER_TOKEN = os.environ.get("BEARER_TOKEN")
# assert BEARER_TOKEN is not None

# Initialise the Pinecone client before any index handle is created.
pinecone.init(
    api_key=PINECONE_API_KEY,  # find at app.pinecone.io
    environment=PINECONE_API_ENV,  # next to api key in console
)

# Connect to the index whose name is given by PINECONE_INDEX.
index_name = os.getenv("PINECONE_INDEX")
index = pinecone.Index(index_name)

# Character budget for the joined context passages placed in a prompt.
limit = 3750
def complete(prompt):
    """Return the chat model's reply to *prompt*.

    The prompt is sent to ``gpt-3.5-turbo`` as a single user message. The
    content of the first returned choice is stripped of leading and trailing
    whitespace and returned.
    """
    request = {
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.5,
        "max_tokens": 400,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
    }
    response = openai.ChatCompletion.create(**request)
    reply = response['choices'][0]['message']['content']
    return reply.strip()
def _build_prompt(query, contexts, max_context_chars):
    """Assemble the question-answering prompt from *contexts*.

    Contexts are taken in order and joined with a ``---`` separator for as
    long as the joined text stays shorter than *max_context_chars*; the first
    context that would reach the budget, and everything after it, is dropped.
    An empty *contexts* list yields a prompt with an empty context section.
    """
    separator = "\n\n---\n\n"
    selected = []
    used_chars = 0
    for context in contexts:
        # The separator is only paid for from the second context onwards.
        added_chars = len(context) + (len(separator) if selected else 0)
        if used_chars + added_chars >= max_context_chars:
            break
        selected.append(context)
        used_chars += added_chars

    return (
        "Answer the question based on the context below.\n\n"
        "Context:\n"
        + separator.join(selected)
        + f"\n\nQuestion: {query}\nAnswer:"
    )


def retrieve(query):
    """Build a context-augmented prompt for *query*.

    The query is embedded with ``text-embedding-ada-002``, the three nearest
    matches are fetched from the Pinecone index together with their metadata,
    and each match's ``metadata['text']`` is used as a context passage.

    Returns:
        The prompt string: an instruction header, the selected contexts, then
        the question followed by ``Answer:``.

    Unlike the previous loop, this always returns a prompt (including when
    the index yields zero or one match) and never lets the joined contexts
    reach the module-level ``limit``.
    """
    embedding_response = openai.Embedding.create(
        input=[query],
        engine="text-embedding-ada-002",
    )
    query_vector = embedding_response['data'][0]['embedding']

    # Nearest-neighbour lookup; metadata is needed to recover the passage text.
    search_response = index.query(query_vector, top_k=3, include_metadata=True)
    contexts = [
        match['metadata']['text'] for match in search_response['matches']
    ]
    return _build_prompt(query, contexts, limit)
def validate_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
    """Validate the bearer token presented with a request.

    The expected token is read from the ``BEARER_TOKEN`` environment variable
    on each call, so this function does not depend on a module-level
    ``BEARER_TOKEN`` constant being defined.

    Returns:
        The credentials object, unchanged, when the token is accepted.

    Raises:
        HTTPException: 401 when the scheme is not ``Bearer``, when the token
            does not match, or when ``BEARER_TOKEN`` is not configured.
    """
    import secrets

    expected_token = os.environ.get("BEARER_TOKEN")
    token_ok = (
        # Fail closed: with no configured token, nothing can authenticate.
        expected_token is not None
        and credentials.scheme == "Bearer"
        # Constant-time comparison; bytes are used so non-ASCII input is safe.
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"),
            expected_token.encode("utf-8"),
        )
    )
    if not token_ok:
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials
# Token-protected variant, currently disabled:
# app = FastAPI(dependencies=[Depends(validate_token)])
app = FastAPI()
# app.mount("/.well-known", StaticFiles(directory=".well-known"), name="static")

# Sub-application mounted under /sub so that a separate OpenAPI schema is
# served at http://0.0.0.0:8000/sub/openapi.json when the app runs locally.
# NOTE(review): no routes are attached to sub_app in the code visible here;
# confirm where its endpoints are registered.
sub_app = FastAPI(
    title="Retrieval Plugin API",
    description="A retrieval API for querying and filtering documents based on natural language queries and metadata",
    version="1.0.0",
    servers=[{"url": "https://your-app-url.com"}],
    # dependencies=[Depends(validate_token)],
)
app.mount("/sub", sub_app)
class SearchQuery(BaseModel):
    # Response schema for GET /search. Despite the field name, the handler
    # fills `query` with the generated answer text.
    query: str
@app.get("/search", response_model=SearchQuery)
async def search(query: Optional[str] = None):
    """Answer *query* using retrieved context.

    Builds a context-augmented prompt with ``retrieve`` and passes it to
    ``complete``. Both helpers make synchronous calls, so they are run on the
    default executor to keep the event loop free for other requests.

    Returns:
        A dict with a single ``"query"`` key holding the generated answer.

    Raises:
        HTTPException: 400 when *query* is missing or empty.
    """
    import asyncio

    if not query:
        raise HTTPException(status_code=400, detail="Please provide a query.")

    loop = asyncio.get_running_loop()
    prompt = await loop.run_in_executor(None, retrieve, query)
    answer = await loop.run_in_executor(None, complete, prompt)
    return {"query": answer}
def start():
    """Launch the application under uvicorn on 0.0.0.0:8000 with reload on."""
    uvicorn.run(
        "server.main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )