-
Notifications
You must be signed in to change notification settings - Fork 5
/
text_generation.go
184 lines (169 loc) · 6.65 KB
/
text_generation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
package hfapigo
import (
"encoding/json"
"errors"
"fmt"
)
// RecommendedTextGenerationModel is the model ID this package recommends for
// text generation requests.
const RecommendedTextGenerationModel = "microsoft/phi-2"

// Grammar type names ("json" and "regex"). They are not referenced in this
// file; presumably they pair with TextGenerationParameters.Grammar — confirm.
const (
	TextGenerationGrammarTypeJSON  = "json"
	TextGenerationGrammarTypeRegex = "regex"
)
// TextGenerationRequest is the request body for SendTextGenerationRequest,
// which JSON-encodes it and passes the bytes to MakeHFAPIRequest.
type TextGenerationRequest struct {
	// Input is the prompt text to generate from (required). It is encoded
	// under the "inputs" key and omitted from the JSON when empty.
	Input string `json:"inputs,omitempty"`

	// Parameters are the optional generation parameters. encoding/json
	// ignores omitempty on struct-typed fields, so "parameters" is always
	// encoded (as {} when no parameter has been set).
	Parameters TextGenerationParameters `json:"parameters,omitempty"`

	// Options is declared elsewhere in this package. NOTE(review): if it is
	// a struct type, omitempty has no effect here either — confirm.
	Options Options `json:"options,omitempty"`
}
// TextGenerationParameters holds the optional generation parameters sent as
// the "parameters" object of a TextGenerationRequest.
//
// Scalar fields are pointers and every field is tagged omitempty, so a field
// that was never set is left out of the encoded JSON instead of being sent as
// a zero value. Build a value with NewTextGenerationParameters and the
// chainable Set* methods.
//
// NOTE(review): Grammar is encoded as a bare JSON string. The
// text-generation-inference server documents "grammar" as an object of the
// form {"type": "json"|"regex", "value": ...} (compare the
// TextGenerationGrammarType* constants), so a plain string may be rejected —
// confirm against the target endpoint.
type TextGenerationParameters struct {
	BestOf              *int     `json:"best_of,omitempty"`
	DecoderInputDetails *bool    `json:"decoder_input_details,omitempty"`
	Details             *bool    `json:"details,omitempty"`
	DoSample            *bool    `json:"do_sample,omitempty"`
	FrequencyPenalty    *float64 `json:"frequency_penalty,omitempty"`
	Grammar             *string  `json:"grammar,omitempty"`
	MaxNewTokens        *int     `json:"max_new_tokens,omitempty"`
	RepetitionPenalty   *float64 `json:"repetition_penalty,omitempty"`
	ReturnFullText      *bool    `json:"return_full_text,omitempty"`
	Seed                *int64   `json:"seed,omitempty"`
	Stop                []string `json:"stop,omitempty"`
	Temperature         *float64 `json:"temperature,omitempty"`
	TopK                *int     `json:"top_k,omitempty"`
	TopNTokens          *int     `json:"top_n_tokens,omitempty"`
	TopP                *float64 `json:"top_p,omitempty"`
	Truncate            *int     `json:"truncate,omitempty"`
	TypicalP            *float64 `json:"typical_p,omitempty"`
	Watermark           *bool    `json:"watermark,omitempty"`
}

// NewTextGenerationParameters returns a parameter set with every field unset;
// chain Set* calls on the result to fill it in.
func NewTextGenerationParameters() *TextGenerationParameters {
	return &TextGenerationParameters{}
}

// SetBestOf sets the "best_of" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetBestOf(bestOf int) *TextGenerationParameters {
	p.BestOf = &bestOf
	return p
}

// SetDecoderInputDetails sets the "decoder_input_details" parameter and
// returns p for chaining.
func (p *TextGenerationParameters) SetDecoderInputDetails(decoderInputDetails bool) *TextGenerationParameters {
	p.DecoderInputDetails = &decoderInputDetails
	return p
}

// SetDetails sets the "details" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetDetails(details bool) *TextGenerationParameters {
	p.Details = &details
	return p
}

// SetDoSample sets the "do_sample" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetDoSample(doSample bool) *TextGenerationParameters {
	p.DoSample = &doSample
	return p
}

// SetFrequencyPenalty sets the "frequency_penalty" parameter and returns p
// for chaining.
func (p *TextGenerationParameters) SetFrequencyPenalty(frequencyPenalty float64) *TextGenerationParameters {
	p.FrequencyPenalty = &frequencyPenalty
	return p
}

// SetGrammar sets the "grammar" parameter (encoded as a JSON string) and
// returns p for chaining.
func (p *TextGenerationParameters) SetGrammar(grammar string) *TextGenerationParameters {
	p.Grammar = &grammar
	return p
}

// SetMaxNewTokens sets the "max_new_tokens" parameter and returns p for
// chaining.
func (p *TextGenerationParameters) SetMaxNewTokens(maxNewTokens int) *TextGenerationParameters {
	p.MaxNewTokens = &maxNewTokens
	return p
}

// SetRepetitionPenalty sets the "repetition_penalty" parameter and returns p
// for chaining.
func (p *TextGenerationParameters) SetRepetitionPenalty(repetitionPenalty float64) *TextGenerationParameters {
	p.RepetitionPenalty = &repetitionPenalty
	return p
}

// SetReturnFullText sets the "return_full_text" parameter and returns p for
// chaining.
func (p *TextGenerationParameters) SetReturnFullText(returnFullText bool) *TextGenerationParameters {
	p.ReturnFullText = &returnFullText
	return p
}

// SetSeed sets the "seed" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetSeed(seed int64) *TextGenerationParameters {
	p.Seed = &seed
	return p
}

// SetStop sets the "stop" parameter and returns p for chaining.
//
// The slice is copied, so later changes to the caller's slice do not leak
// into p. This matches the other setters, which all store a copy of their
// argument. Passing nil or an empty slice clears the parameter.
func (p *TextGenerationParameters) SetStop(stop []string) *TextGenerationParameters {
	p.Stop = append([]string(nil), stop...)
	return p
}

// SetTemperature sets the "temperature" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetTemperature(temperature float64) *TextGenerationParameters {
	p.Temperature = &temperature
	return p
}

// SetTopK sets the "top_k" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetTopK(topK int) *TextGenerationParameters {
	p.TopK = &topK
	return p
}

// SetTopNTokens sets the "top_n_tokens" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetTopNTokens(topNTokens int) *TextGenerationParameters {
	p.TopNTokens = &topNTokens
	return p
}

// SetTopP sets the "top_p" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetTopP(topP float64) *TextGenerationParameters {
	p.TopP = &topP
	return p
}

// SetTruncate sets the "truncate" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetTruncate(truncate int) *TextGenerationParameters {
	p.Truncate = &truncate
	return p
}

// SetTypicalP sets the "typical_p" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetTypicalP(typicalP float64) *TextGenerationParameters {
	p.TypicalP = &typicalP
	return p
}

// SetWatermark sets the "watermark" parameter and returns p for chaining.
func (p *TextGenerationParameters) SetWatermark(watermark bool) *TextGenerationParameters {
	p.Watermark = &watermark
	return p
}

// SetRepetitionPenaly is a misspelled alias of SetRepetitionPenalty, kept so
// that existing callers continue to compile.
//
// Deprecated: use SetRepetitionPenalty instead.
func (p *TextGenerationParameters) SetRepetitionPenaly(penalty float64) *TextGenerationParameters {
	return p.SetRepetitionPenalty(penalty)
}
// TextGenerationResponse is one element of the JSON array that
// SendTextGenerationRequest decodes from the response body.
type TextGenerationResponse struct {
	GeneratedText string                        `json:"generated_text,omitempty"`
	Details       TextGenerationResponseDetails `json:"details,omitempty"`
}

// TextGenerationResponseDetails holds the "details" object of a response.
//
// "top_tokens" is accepted in two shapes:
//   - a list of per-step token lists (one inner list per generated token), as
//     TextGenerationBestOfSequence.TopTokens already models;
//   - a flat list of tokens.
//
// When the nested shape is received, TopTokensPerStep keeps the per-step
// grouping. TopTokens always holds all top tokens flattened in order, so
// callers that read TopTokens keep working. TopTokensPerStep is never
// encoded; marshalling emits the flattened TopTokens.
type TextGenerationResponseDetails struct {
	BestOfSequences  []*TextGenerationBestOfSequence `json:"best_of_sequences,omitempty"`
	FinishReason     string                          `json:"finish_reason,omitempty"`
	GeneratedTokens  int                             `json:"generated_tokens,omitempty"`
	Prefill          []*TextGenerationPrefillToken   `json:"prefill,omitempty"`
	Seed             int64                           `json:"seed,omitempty"`
	Tokens           []*TextGenerationToken          `json:"tokens,omitempty"`
	TopTokens        []*TextGenerationToken          `json:"top_tokens,omitempty"`
	TopTokensPerStep [][]*TextGenerationToken        `json:"-"`
}

// UnmarshalJSON decodes a details object. It accepts "top_tokens" either as a
// list of per-step token lists or as a flat token list. A nested list with a
// flat field type would otherwise make decoding of the whole response fail.
func (d *TextGenerationResponseDetails) UnmarshalJSON(data []byte) error {
	// detailsFields has the same fields but no methods, so decoding into it
	// does not recurse back into this UnmarshalJSON.
	type detailsFields TextGenerationResponseDetails
	aux := struct {
		*detailsFields
		// Shadows the embedded TopTokens (shallower fields win in
		// encoding/json) so the raw value can be inspected below.
		TopTokens json.RawMessage `json:"top_tokens,omitempty"`
	}{detailsFields: (*detailsFields)(d)}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}

	d.TopTokens, d.TopTokensPerStep = nil, nil
	if len(aux.TopTokens) == 0 {
		return nil
	}

	var perStep [][]*TextGenerationToken
	if err := json.Unmarshal(aux.TopTokens, &perStep); err == nil {
		d.TopTokensPerStep = perStep
		for _, step := range perStep {
			d.TopTokens = append(d.TopTokens, step...)
		}
		return nil
	}

	var flat []*TextGenerationToken
	if err := json.Unmarshal(aux.TopTokens, &flat); err != nil {
		return fmt.Errorf("decoding top_tokens: %w", err)
	}
	d.TopTokens = flat
	return nil
}

// TextGenerationBestOfSequence is one entry of
// TextGenerationResponseDetails.BestOfSequences. Its TopTokens field holds
// one token list per step.
type TextGenerationBestOfSequence struct {
	FinishReason    string                        `json:"finish_reason,omitempty"`
	GeneratedText   string                        `json:"generated_text,omitempty"`
	GeneratedTokens int                           `json:"generated_tokens,omitempty"`
	Prefill         []*TextGenerationPrefillToken `json:"prefill,omitempty"`
	Seed            int64                         `json:"seed,omitempty"`
	Tokens          []*TextGenerationToken        `json:"tokens,omitempty"`
	TopTokens       [][]*TextGenerationToken      `json:"top_tokens,omitempty"`
}

// TextGenerationPrefillToken is a token entry carrying an id, a log
// probability and its text.
type TextGenerationPrefillToken struct {
	ID      int     `json:"id,omitempty"`
	LogProb float64 `json:"logprob,omitempty"`
	Text    string  `json:"text,omitempty"`
}

// TextGenerationToken is a TextGenerationPrefillToken plus a "special" flag.
// The embedded struct's fields are encoded at the same JSON level.
type TextGenerationToken struct {
	TextGenerationPrefillToken
	Special bool `json:"special,omitempty"`
}
// SendTextGenerationRequest sends a text generation request to model.
//
// It JSON-encodes request, sends it via MakeHFAPIRequest, and decodes the
// response body as a JSON array of TextGenerationResponse values.
//
// It returns an error in any of these cases:
//   - request is nil;
//   - encoding or decoding fails (wrapped with %w, so errors.As still finds
//     the underlying encoding/json error);
//   - the decoded array is empty;
//   - any array element is JSON null.
//
// Errors from MakeHFAPIRequest are returned unwrapped, as before.
func SendTextGenerationRequest(model string, request *TextGenerationRequest) ([]*TextGenerationResponse, error) {
	if request == nil {
		return nil, errors.New("nil TextGenerationRequest")
	}

	jsonBuf, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("encoding text generation request: %w", err)
	}

	respBody, err := MakeHFAPIRequest(jsonBuf, model)
	if err != nil {
		return nil, err
	}

	// json.Unmarshal resets the slice and appends decoded elements, so no
	// pre-allocation is needed. A JSON null body leaves tgresps nil, which
	// the length check below reports.
	var tgresps []*TextGenerationResponse
	if err := json.Unmarshal(respBody, &tgresps); err != nil {
		return nil, fmt.Errorf("decoding text generation response: %w", err)
	}
	if len(tgresps) < 1 {
		return nil, fmt.Errorf("expected at least 1 response, got none; response=%s", string(respBody))
	}
	// A null array element decodes to a nil pointer; reject it here rather
	// than let callers panic when they dereference it.
	for i, resp := range tgresps {
		if resp == nil {
			return nil, fmt.Errorf("response element %d is null; response=%s", i, string(respBody))
		}
	}
	return tgresps, nil
}