-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict-vllm.py
80 lines (69 loc) · 2.43 KB
/
predict-vllm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
from google.cloud import aiplatform
# Deployment coordinates used by the rest of this script. The values below
# look like placeholders; substitute the ones for your own deployment.
PROJECT_ID = "<PROJECT_ID>"
REGION = "us-central1"
ENDPOINT_ID = "ENDPOINT_ID"

# Echo the active configuration, one "NAME: value" line per setting.
print(
    "\n".join(
        (
            f"PROJECT_ID: {PROJECT_ID}",
            f"REGION: {REGION}",
            f"ENDPOINT_ID: {ENDPOINT_ID}",
        )
    )
)
import logging
# Configure the root logger so DEBUG-level records (and above) are emitted.
logging.basicConfig(level=logging.DEBUG)

# Initialise the SDK with the project and region configured above.
aiplatform.init(project=PROJECT_ID, location=REGION)

# Low-level prediction client pointed at the regional API host.
client = aiplatform.gapic.PredictionServiceClient(
    client_options=dict(api_endpoint=f"{REGION}-aiplatform.googleapis.com"),
)

# Registry mapping a short label to an endpoint resource path; extend it
# with further entries to match your own deployed endpoints.
endpoints = dict(
    vllm_gpu=client.endpoint_path(
        project=PROJECT_ID,
        location=REGION,
        endpoint=ENDPOINT_ID,
    ),
)
print(endpoints)
def predict_vllm(
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    raw_response: bool,
    lora_weight: str = "",
) -> None:
    """Send one prediction request to the ``vllm_gpu`` endpoint and print the results.

    A single instance dict is built from the arguments and passed to the
    module-level ``client.predict`` call, targeting ``endpoints["vllm_gpu"]``.
    Each entry of ``response.predictions`` is printed on its own line.

    Args:
        prompt: Text sent under the ``"prompt"`` key.
        max_tokens: Value sent under the ``"max_tokens"`` key.
        temperature: Value sent under the ``"temperature"`` key.
        top_p: Value sent under the ``"top_p"`` key.
        top_k: Value sent under the ``"top_k"`` key.
        raw_response: Value sent under the ``"raw_response"`` key.
        lora_weight: Optional LoRA weight path, for example
            ``gs://bucket/path/meta-llama/saves/llama3-8b/lora/sft``. When
            non-empty it is sent under the ``"dynamic-lora"`` key; when empty
            (the default) the key is omitted from the request.
    """
    target_endpoint = endpoints["vllm_gpu"]
    logging.debug("Sending prediction request to endpoint: %s", target_endpoint)

    # Parameters for inference.
    instance = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "raw_response": raw_response,
    }

    # Honour the caller-supplied LoRA path. It used to be overwritten here by
    # a hard-coded placeholder, so the argument was ignored and the
    # placeholder was attached to every request.
    if lora_weight:
        print(f"------------- Using lora_weight: {lora_weight} ---------\n")
        instance["dynamic-lora"] = lora_weight

    instances = [instance]

    # Call the predict method on the module-level gapic client.
    logging.info("Sending prediction request: %s", instances)
    response = client.predict(
        endpoint=target_endpoint,  # Fully-qualified endpoint path
        instances=instances,
    )
    logging.debug("Received response: %s", response)

    for prediction in response.predictions:
        print(prediction)
# Read the prompt text interactively from standard input.
prompt = input("Enter your prompt: ")  # @param {type: "string"}

# Generation settings forwarded to predict_vllm.
# If the request fails with an error such as
# `ServiceUnavailable: 503 Took too long to respond when processing`,
# reduce the maximum number of output tokens, e.g. set `max_tokens` to 20.
max_tokens = 50  # @param {type:"integer"}
temperature = 1.0  # @param {type:"number"}
top_p = 1.0  # @param {type:"number"}
top_k = 1  # @param {type:"integer"}
raw_response = False  # @param {type:"boolean"}

# Arguments are passed positionally, in the order predict_vllm declares them.
predict_vllm(prompt, max_tokens, temperature, top_p, top_k, raw_response)