generate_response.py
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

# Load environment variables (expects HUGGING_FACE_TOKEN in a .env file)
load_dotenv()
HUGGING_FACE_TOKEN = os.environ["HUGGING_FACE_TOKEN"]


def generate_llm_response(context, question):
    # Prompt template: instructions, retrieved recipe context, and the user question
    prompt_template = """
### [INST]
You are a helpful assistant who provides cooking advice.
Use the recipe information below as relevant context to help answer the user question.
If there is no recipe in the context, answer the user question based on your general knowledge.

### CONTEXT:
{context}

### QUESTION:
{question}

### RESPONSE:
[/INST]
"""

    # Hugging Face model ID to query via the Inference API
    model_id = "microsoft/Phi-3.5-mini-instruct"
    client = InferenceClient(api_key=HUGGING_FACE_TOKEN)

    # Fill the template with the provided context and question
    prompt = prompt_template.format(context=context, question=question)

    # Accumulate the streamed response text
    response_text = ""

    # Generate the response, streaming it chunk by chunk
    for message in client.chat_completion(
        model=model_id,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1000,
        stream=True,
    ):
        response_chunk = message.choices[0].delta.content
        if response_chunk is None:  # the final streamed chunk may carry no content
            continue
        response_text += response_chunk
        # Print each chunk immediately; flush=True forces unbuffered output
        print(response_chunk, end="", flush=True)

    # Load the model's tokenizer and count the tokens in the full response
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    total_tokens = len(tokenizer.encode(response_text))

    return response_text, total_tokens
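
# A possible refinement (a sketch, not part of the original script):
# AutoTokenizer.from_pretrained() reloads the tokenizer on every call, so
# repeated calls could cache it instead, e.g. with functools.lru_cache.
# `get_tokenizer` is a hypothetical helper introduced here for illustration:
#
# from functools import lru_cache
#
# @lru_cache(maxsize=1)
# def get_tokenizer(model_id):
#     return AutoTokenizer.from_pretrained(model_id)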
# # Example usage
# if __name__ == "__main__":
#     response_text, total_tokens = generate_llm_response(
#         question="Can you provide me an Indian recipe?", context=None
#     )
#     print(response_text)
#     print(total_tokens)
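
# A further sketch (not in the original) of calling the function with a
# retrieved recipe passed as context; `recipe` is a stand-in for whatever
# string a retrieval step would return:
#
# if __name__ == "__main__":
#     recipe = "Chana Masala: simmer chickpeas with onion, tomato, garlic, and garam masala."
#     response_text, total_tokens = generate_llm_response(
#         question="Can I use dried chickpeas instead of canned?", context=recipe
#     )
#     print(f"\n\nTokens in response: {total_tokens}")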