# [LLM](https://llm.datasette.io/) plugin providing access to [Groqcloud](http://console.groq.com) models.
# Based on llm-mistral (https://github.com/simonw/llm-mistral)
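#
# Usage sketch (hedged, assuming the plugin is installed into LLM): the model
# aliases registered below become available to the `llm` CLI, with the API key
# taken from the "groq" entry in the LLM key store or the LLM_GROQ_KEY
# environment variable, e.g.:
#
#   llm keys set groq
#   llm -m groq-llama3-8b "Ten fun facts about pelicans"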
import llm
from groq import Groq
from pydantic import Field
from typing import Optional, List, Union


@llm.hookimpl
def register_models(register):
    register(LLMGroq("groq-llama2"))
    register(LLMGroq("groq-llama3-8b"))
    register(LLMGroq("groq-llama3-70b"))
    register(LLMGroq("groq-mixtral"))
    register(LLMGroq("groq-gemma"))


class LLMGroq(llm.Model):
    can_stream = True

    # Maps the aliases registered with LLM to the model IDs used by the Groq API.
    model_map: dict = {
        "groq-llama2": "llama2-70b-4096",
        "groq-llama3-8b": "llama3-8b-8192",
        "groq-llama3-70b": "llama3-70b-8192",
        "groq-mixtral": "mixtral-8x7b-32768",
        "groq-gemma": "gemma-7b-it",
    }

    class Options(llm.Options):
        temperature: Optional[float] = Field(
            description=(
                "Controls randomness of responses. A lower temperature leads to "
                "more predictable outputs, while a higher temperature results in "
                "more varied and sometimes more creative outputs. "
                "As the temperature approaches zero, the model becomes deterministic "
                "and repetitive."
            ),
            ge=0,
            le=1,
            default=None,
        )
        top_p: Optional[float] = Field(
            description=(
                "Controls diversity via nucleus sampling: only the most likely "
                "tokens whose cumulative probability reaches top_p are considered. "
                "0.5 means half of all likelihood-weighted options are considered."
            ),
            ge=0,
            le=1,
            default=None,
        )
        max_tokens: Optional[int] = Field(
            description=(
                "The maximum number of tokens that the model can process in a "
                "single response. This limit ensures computational efficiency "
                "and resource management. "
                "Requests can use up to 2048 tokens shared between prompt and completion."
            ),
            ge=0,
            lt=2049,
            default=None,
        )
        stop: Optional[Union[str, List[str]]] = Field(
            description=(
                "A stop sequence is a predefined or user-specified text string that "
                "signals the model to stop generating content, ensuring its responses "
                "remain focused and concise. Examples include punctuation marks and "
                "markers like \"[end]\". "
                "For example, \", 6\" makes the model stop counting at 5. "
                "If multiple stop values are needed, a list of strings may be passed, "
                "e.g. stop=[\", 6\", \", six\", \", Six\"]."
            ),
            default=None,
        )
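
    # The options above map onto `-o name value` flags on the llm CLI. A hedged
    # example (the alias and values are illustrative, not part of this file):
    #
    #   llm -m groq-llama3-70b -o temperature 0.2 -o max_tokens 256 "Explain nucleus sampling"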

    def __init__(self, model_id):
        self.model_id = model_id

    def build_messages(self, prompt, conversation):
        messages = []
        if not conversation:
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            messages.append({"role": "user", "content": prompt.prompt})
            return messages
        current_system = None
        for prev_response in conversation.responses:
            if (
                prev_response.prompt.system
                and prev_response.prompt.system != current_system
            ):
                messages.append(
                    {"role": "system", "content": prev_response.prompt.system}
                )
                current_system = prev_response.prompt.system
            messages.append({"role": "user", "content": prev_response.prompt.prompt})
            messages.append({"role": "assistant", "content": prev_response.text()})
        if prompt.system and prompt.system != current_system:
            messages.append({"role": "system", "content": prompt.system})
        messages.append({"role": "user", "content": prompt.prompt})
        return messages
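
    # For a follow-up prompt in a conversation holding one prior exchange plus a
    # system prompt, build_messages returns a list shaped like this (values are
    # illustrative):
    #
    #   [
    #       {"role": "system", "content": "Be terse."},
    #       {"role": "user", "content": "First question"},
    #       {"role": "assistant", "content": "First answer"},
    #       {"role": "user", "content": "Follow-up question"},
    #   ]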

    def execute(self, prompt, stream, response, conversation):
        # Resolve the API key from the "groq" entry in the LLM key store or the
        # LLM_GROQ_KEY environment variable.
        key = llm.get_key("", "groq", "LLM_GROQ_KEY")
        messages = self.build_messages(prompt, conversation)
        client = Groq(api_key=key)
        resp = client.chat.completions.create(
            messages=messages,
            model=self.model_map[self.model_id],
            stream=stream,
            temperature=prompt.options.temperature,
            top_p=prompt.options.top_p,
            max_tokens=prompt.options.max_tokens,
            stop=prompt.options.stop,
        )
        if stream:
            for chunk in resp:
                if chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content
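

# A minimal sketch of calling the plugin from the LLM Python API rather than the
# CLI (assumes the plugin is installed and a Groq key is configured; the model
# alias and prompt are illustrative):
#
#   import llm
#
#   model = llm.get_model("groq-llama3-8b")
#   response = model.prompt("Three facts about pelicans")
#   print(response.text())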