forked from openai/openai-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
polling.go
134 lines (120 loc) · 4.09 KB
/
polling.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package openai
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"github.com/joschahenningsen/openai-go/option"
)
func mkPollingOptions(pollIntervalMs int) []option.RequestOption {
options := []option.RequestOption{option.WithHeader("X-Stainless-Poll-Helper", "true")}
if pollIntervalMs > 0 {
options = append(options, option.WithHeader("X-Stainless-Poll-Interval", fmt.Sprintf("%d", pollIntervalMs)))
}
return options
}
func getPollInterval(raw *http.Response) (ms int) {
if ms, err := strconv.Atoi(raw.Header.Get("openai-poll-after-ms")); err == nil {
return ms
}
return 1000
}
// PollStatus waits until a VectorStoreFile is no longer in an incomplete state and returns it.
// Pass 0 as pollIntervalMs to use the default polling interval of 1 second.
func (r *BetaVectorStoreFileService) PollStatus(ctx context.Context, vectorStoreID string, fileID string, pollIntervalMs int, opts ...option.RequestOption) (*VectorStoreFile, error) {
var raw *http.Response
opts = append(opts, mkPollingOptions(pollIntervalMs)...)
opts = append(opts, option.WithResponseInto(&raw))
for {
file, err := r.Get(ctx, vectorStoreID, fileID, opts...)
if err != nil {
return nil, fmt.Errorf("vector store file poll: received %w", err)
}
switch file.Status {
case VectorStoreFileStatusInProgress:
if pollIntervalMs <= 0 {
pollIntervalMs = getPollInterval(raw)
}
time.Sleep(time.Duration(pollIntervalMs) * time.Millisecond)
case VectorStoreFileStatusCancelled,
VectorStoreFileStatusCompleted,
VectorStoreFileStatusFailed:
return file, nil
default:
return nil, fmt.Errorf("invalid vector store file status during polling: received %s", file.Status)
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
}
// PollStatus waits until a BetaVectorStoreFileBatch is no longer in an incomplete state and returns it.
// Pass 0 as pollIntervalMs to use the default polling interval of 1 second.
func (r *BetaVectorStoreFileBatchService) PollStatus(ctx context.Context, vectorStoreID string, batchID string, pollIntervalMs int, opts ...option.RequestOption) (*VectorStoreFileBatch, error) {
var raw *http.Response
opts = append(opts, option.WithResponseInto(&raw))
opts = append(opts, mkPollingOptions(pollIntervalMs)...)
for {
batch, err := r.Get(ctx, vectorStoreID, batchID, opts...)
if err != nil {
return nil, fmt.Errorf("vector store file batch poll: received %w", err)
}
switch batch.Status {
case VectorStoreFileBatchStatusInProgress:
if pollIntervalMs <= 0 {
pollIntervalMs = getPollInterval(raw)
}
time.Sleep(time.Duration(pollIntervalMs) * time.Millisecond)
case VectorStoreFileBatchStatusCancelled,
VectorStoreFileBatchStatusCompleted,
VectorStoreFileBatchStatusFailed:
return batch, nil
default:
return nil, fmt.Errorf("invalid vector store file batch status during polling: received %s", batch.Status)
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
}
// PollStatus waits until a Run is no longer in an incomplete state and returns it.
// Pass 0 as pollIntervalMs to use the default polling interval of 1 second.
func (r *BetaThreadRunService) PollStatus(ctx context.Context, threadID string, runID string, pollIntervalMs int, opts ...option.RequestOption) (res *Run, err error) {
var raw *http.Response
opts = append(opts, mkPollingOptions(pollIntervalMs)...)
opts = append(opts, option.WithResponseInto(&raw))
for {
run, err := r.Get(ctx, threadID, runID, opts...)
if err != nil {
return nil, fmt.Errorf("thread run poll: received %w", err)
}
switch run.Status {
case RunStatusInProgress,
RunStatusQueued:
if pollIntervalMs <= 0 {
pollIntervalMs = getPollInterval(raw)
}
time.Sleep(time.Duration(pollIntervalMs) * time.Millisecond)
case RunStatusRequiresAction,
RunStatusCancelled,
RunStatusCompleted,
RunStatusFailed,
RunStatusExpired,
RunStatusIncomplete:
return run, nil
default:
return nil, fmt.Errorf("invalid thread run status during polling: received %s", run.Status)
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
break
}
}
}