This repository has been archived by the owner on Jun 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
options.go
110 lines (94 loc) · 2.31 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package gptj
// PredictOptions collects the tunable parameters for a single prediction.
// Use NewPredictOptions to obtain a value pre-populated from DefaultOptions.
type PredictOptions struct {
	ContextSize int // context size; NOTE(review): unit not visible here, presumably tokens — confirm
	RepeatLastN int // NOTE(review): presumably how many recent tokens the repeat penalty looks at — confirm
	Tokens      int // number of tokens to generate
	TopK        int // value for top-K sampling
	Batch       int // batch size

	TopP          float64 // value for nucleus sampling
	Temperature   float64 // temperature value for text generation
	ContextErase  float64 // context erase amount; NOTE(review): described as a %, default 0.55 suggests a fraction — confirm
	RepeatPenalty float64 // repeat penalty
}
// PredictOption is a functional option: a function that mutates the
// PredictOptions it is handed. Options are applied by NewPredictOptions.
type PredictOption func(*PredictOptions)
// DefaultOptions holds the baseline prediction settings. NewPredictOptions
// starts from a copy of this value before applying caller-supplied options.
var DefaultOptions = PredictOptions{
	// Output length and context.
	Tokens:       200,
	ContextSize:  1024,
	ContextErase: 0.55,
	Batch:        1,

	// Sampling.
	TopK:        10,
	TopP:        0.90,
	Temperature: 0.96,

	// Repetition control.
	RepeatLastN:   10,
	RepeatPenalty: 1.2,
}
// SetTokens returns an option that sets the number of tokens to generate
// (PredictOptions.Tokens) to tokens.
func SetTokens(tokens int) PredictOption {
	return func(o *PredictOptions) { o.Tokens = tokens }
}
// SetTopK returns an option that sets the top-K sampling value
// (PredictOptions.TopK) to topk.
func SetTopK(topk int) PredictOption {
	apply := func(o *PredictOptions) {
		o.TopK = topk
	}
	return apply
}
// SetTopP returns an option that sets the nucleus sampling value
// (PredictOptions.TopP) to topp.
func SetTopP(topp float64) PredictOption {
	return func(o *PredictOptions) { o.TopP = topp }
}
// SetRepeatPenalty returns an option that sets the repeat penalty
// (PredictOptions.RepeatPenalty) to penalty.
func SetRepeatPenalty(penalty float64) PredictOption {
	return func(p *PredictOptions) {
		p.RepeatPenalty = penalty
	}
}
// SetRepeatLastN returns an option that sets PredictOptions.RepeatLastN to n.
//
// NOTE(review): presumably n is the number of most recent tokens taken into
// account by the repeat penalty — confirm against the code consuming it.
func SetRepeatLastN(n int) PredictOption {
	return func(p *PredictOptions) {
		p.RepeatLastN = n
	}
}
// SetContextErase returns an option that sets the context erase amount
// (PredictOptions.ContextErase) to ce.
func SetContextErase(ce float64) PredictOption {
	apply := func(o *PredictOptions) {
		o.ContextErase = ce
	}
	return apply
}
// SetTemperature returns an option that sets the temperature used for text
// generation (PredictOptions.Temperature) to temp.
func SetTemperature(temp float64) PredictOption {
	return func(o *PredictOptions) { o.Temperature = temp }
}
// SetBatch returns an option that sets the batch size
// (PredictOptions.Batch) to size.
func SetBatch(size int) PredictOption {
	apply := func(o *PredictOptions) {
		o.Batch = size
	}
	return apply
}
// NewPredictOptions returns a PredictOptions value that starts as a copy of
// DefaultOptions and then has each of opts applied in the order given, so a
// later option overrides an earlier one that sets the same field.
//
// Nil entries in opts are skipped rather than called, which would panic.
func NewPredictOptions(opts ...PredictOption) PredictOptions {
	p := DefaultOptions // value copy: options mutate p, not the package-level default
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(&p)
	}
	return p
}
// DefaultModelOptions holds the baseline model settings. NewModelOptions
// starts from a copy of this value before applying caller-supplied options.
var DefaultModelOptions = ModelOptions{Threads: 4}
// ModelOptions collects the settings applied to a model. Use NewModelOptions
// to obtain a value pre-populated from DefaultModelOptions.
type ModelOptions struct {
	// Threads is the number of threads to use for text generation.
	Threads int
}
// ModelOption is a functional option: a function that mutates the
// ModelOptions it is handed. Options are applied by NewModelOptions.
type ModelOption func(*ModelOptions)
// SetThreads returns an option that sets the number of threads to use for
// text generation (ModelOptions.Threads) to c.
func SetThreads(c int) ModelOption {
	return func(o *ModelOptions) { o.Threads = c }
}
// NewModelOptions returns a ModelOptions value that starts as a copy of
// DefaultModelOptions and then has each of opts applied in the order given,
// so a later option overrides an earlier one that sets the same field.
//
// Nil entries in opts are skipped rather than called, which would panic.
func NewModelOptions(opts ...ModelOption) ModelOptions {
	p := DefaultModelOptions // value copy: options mutate p, not the package-level default
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(&p)
	}
	return p
}