stapi.patch
diff --git a/main.py b/main.py
index 12ec834..6add170 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@ import os
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer
 
 models: Dict[str, SentenceTransformer] = {}
 model_name = os.getenv("MODEL", "all-MiniLM-L6-v2")
@@ -19,12 +20,26 @@ class EmbeddingRequest(BaseModel):
         default=model_name,
     )
 
+class TokenizeRequest(BaseModel):
+    input: Union[str, List[str]] = Field(
+        examples=["substratus.ai provides the best LLM tools"]
+    )
+    model: str = Field(
+        examples=[model_name],
+        default=model_name,
+    )
+
 
 class EmbeddingData(BaseModel):
     embedding: List[float]
     index: int
     object: str
 
+
+class TokenData(BaseModel):
+    embedding: List[int]
+    index: int
+    object: str
+
 
 class Usage(BaseModel):
     prompt_tokens: int
@@ -37,10 +52,17 @@ class EmbeddingResponse(BaseModel):
     usage: Usage
     object: str
 
+class TokenizeResponse(BaseModel):
+    data: List[TokenData]
+    model: str
+    usage: Usage
+    object: str
+
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     models[model_name] = SentenceTransformer(model_name, trust_remote_code=True)
+    models[model_name + "_tokenizer"] = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     yield
 
 
@@ -50,18 +72,20 @@ app = FastAPI(lifespan=lifespan)
 @app.post("/v1/embeddings")
 async def embedding(item: EmbeddingRequest) -> EmbeddingResponse:
     model: SentenceTransformer = models[model_name]
+    tokenizer: AutoTokenizer = models[model_name + "_tokenizer"]
     if isinstance(item.input, str):
         vectors = model.encode(item.input)
-        tokens = len(vectors)
+        tokens = tokenizer(item.input)
+        token_count = len(tokens["input_ids"])
         return EmbeddingResponse(
             data=[EmbeddingData(embedding=vectors, index=0, object="embedding")],
             model=model_name,
-            usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
+            usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
             object="list",
         )
     if isinstance(item.input, list):
         embeddings = []
-        tokens = 0
+        total_tokens = 0
         for index, text_input in enumerate(item.input):
             if not isinstance(text_input, str):
                 raise HTTPException(
@@ -69,14 +93,16 @@ async def embedding(item: EmbeddingRequest) -> EmbeddingResponse:
                     detail="input needs to be an array of strings or a string",
                 )
             vectors = model.encode(text_input)
-            tokens += len(vectors)
+            tokens = tokenizer(text_input)
+            token_count = len(tokens["input_ids"])
+            total_tokens += token_count
             embeddings.append(
                 EmbeddingData(embedding=vectors, index=index, object="embedding")
             )
         return EmbeddingResponse(
             data=embeddings,
             model=model_name,
-            usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
+            usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
             object="list",
         )
     raise HTTPException(
@@ -88,3 +114,25 @@ async def embedding(item: EmbeddingRequest) -> EmbeddingResponse:
 @app.get("/healthz")
 async def healthz():
     return {"status": "ok"}
+
+@app.post("/v1/tokenize")
+async def tokenize(item: TokenizeRequest) -> TokenizeResponse:
+    tokenizer: AutoTokenizer = models[model_name + "_tokenizer"]
+    data = []
+    if isinstance(item.input, str):
+        tokens = tokenizer(item.input)["input_ids"]
+        token_count = len(tokens)
+        data.append(TokenData(embedding=tokens, index=0, object="list"))
+    elif isinstance(item.input, list):
+        token_count = 0
+        for index, text_input in enumerate(item.input):
+            tokens = tokenizer(text_input)["input_ids"]
+            token_count += len(tokens)
+            data.append(TokenData(embedding=tokens, index=index, object="list"))
+    else:
+        raise HTTPException(
+            status_code=400, detail="input needs to be an array of strings or a string"
+        )
+    usage = Usage(prompt_tokens=token_count, total_tokens=token_count)
+    resp = TokenizeResponse(data=data, model=model_name, usage=usage, object="tokens")
+    return resp
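
The two /v1/embeddings hunks above replace len(vectors), which counted the embedding dimensions of all-MiniLM-L6-v2 as "tokens", with a real tokenizer count. A minimal client sketch to verify the corrected usage figures, assuming the patched server runs locally on port 8000 with the default model (the URL and the requests dependency are assumptions; the endpoint path and response shape come from the patch):

import requests  # third-party HTTP client, assumed installed

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": "substratus.ai provides the best LLM tools"},
)
resp.raise_for_status()
body = resp.json()

# usage.prompt_tokens should now be a small tokenizer count,
# not the embedding width reported by the pre-patch code.
print(body["usage"]["prompt_tokens"])
print(len(body["data"][0]["embedding"]))  # embedding width, e.g. 384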
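
A companion sketch against the new /v1/tokenize endpoint, under the same assumptions. Note that TokenData reuses the field name "embedding" for the list of token ids, mirroring the EmbeddingData schema:

import requests  # assumed installed

resp = requests.post(
    "http://localhost:8000/v1/tokenize",
    json={"input": ["substratus.ai provides the best LLM tools"]},
)
resp.raise_for_status()
body = resp.json()

for entry in body["data"]:
    print(entry["index"], entry["embedding"])  # token ids, special tokens included

# usage totals the token counts across all inputs
print(body["usage"]["total_tokens"])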