Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
l-winston authored Sep 26, 2024
2 parents 498d335 + fdd59d9 commit 11af77b
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 31 deletions.
4 changes: 3 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ type ChatCompletionRequest struct {
MaxTokens int `json:"max_tokens,omitempty"`
// MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
// including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning

MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
Expand All @@ -219,7 +220,8 @@ type ChatCompletionRequest struct {
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
// LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.

// LogitBias i s must be a token id string (specified by their token ID in the tokenizer), not a word string.
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
LogitBias map[string]int `json:"logit_bias,omitempty"`
Expand Down
9 changes: 9 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB
}

func (c *Client) handleErrorResp(resp *http.Response) error {
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error, reading response body: %w", err)
}
return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body)
}
var errRes ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
if err != nil || errRes.Error == nil {
reqErr := &RequestError{
HTTPStatus: resp.Status,
HTTPStatusCode: resp.StatusCode,
Err: err,
}
Expand All @@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error {
return reqErr
}

errRes.Error.HTTPStatus = resp.Status
errRes.Error.HTTPStatusCode = resp.StatusCode
return errRes.Error
}
Expand Down
76 changes: 53 additions & 23 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) {
client := NewClient(mockToken)

testCases := []struct {
name string
httpCode int
body io.Reader
expected string
name string
httpCode int
httpStatus string
contentType string
body io.Reader
expected string
}{
{
name: "401 Invalid Authentication",
httpCode: http.StatusUnauthorized,
name: "401 Invalid Authentication",
httpCode: http.StatusUnauthorized,
contentType: "application/json",
body: bytes.NewReader([]byte(
`{
"error":{
Expand All @@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) {
}
}`,
)),
expected: "error, status code: 401, message: You didn't provide an API key. ....",
expected: "error, status code: 401, status: , message: You didn't provide an API key. ....",
},
{
name: "401 Azure Access Denied",
httpCode: http.StatusUnauthorized,
name: "401 Azure Access Denied",
httpCode: http.StatusUnauthorized,
contentType: "application/json",
body: bytes.NewReader([]byte(
`{
"error":{
Expand All @@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) {
}
}`,
)),
expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.",
},
{
name: "503 Model Overloaded",
httpCode: http.StatusServiceUnavailable,
name: "503 Model Overloaded",
httpCode: http.StatusServiceUnavailable,
contentType: "application/json",
body: bytes.NewReader([]byte(`
{
"error":{
Expand All @@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) {
"code":null
}
}`)),
expected: "error, status code: 503, message: That model...",
expected: "error, status code: 503, status: , message: That model...",
},
{
name: "503 no message (Unknown response)",
httpCode: http.StatusServiceUnavailable,
name: "503 no message (Unknown response)",
httpCode: http.StatusServiceUnavailable,
contentType: "application/json",
body: bytes.NewReader([]byte(`
{
"error":{}
}`)),
expected: "error, status code: 503, message: ",
expected: "error, status code: 503, status: , message: ",
},
{
name: "413 Request Entity Too Large",
httpCode: http.StatusRequestEntityTooLarge,
contentType: "text/html",
body: bytes.NewReader([]byte(`<html>
<head><title>413 Request Entity Too Large</title></head>
<body>
<center><h1>413 Request Entity Too Large</h1></center>
<hr><center>nginx</center>
</body>
</html>`)),
expected: `error, status code: 413, status: , body: <html>
<head><title>413 Request Entity Too Large</title></head>
<body>
<center><h1>413 Request Entity Too Large</h1></center>
<hr><center>nginx</center>
</body>
</html>`,
},
{
name: "errorReader",
httpCode: http.StatusRequestEntityTooLarge,
contentType: "text/html",
body: &errorReader{err: errors.New("errorReader")},
expected: "error, reading response body: errorReader",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testCase := &http.Response{}
testCase := &http.Response{
Header: map[string][]string{
"Content-Type": {tc.contentType},
},
}
testCase.StatusCode = tc.httpCode
testCase.Body = io.NopCloser(tc.body)
err := client.handleErrorResp(testCase)
Expand All @@ -203,12 +239,6 @@ func TestHandleErrorResp(t *testing.T) {
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
t.Fail()
}

e := &APIError{}
if !errors.As(err, &e) {
t.Errorf("(%s) Expected error to be of type APIError", tc.name)
t.Fail()
}
})
}
}
Expand Down
12 changes: 9 additions & 3 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ package openai

// Usage Represents the total token usage per request to OpenAI.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"`
}

// CompletionTokensDetails Breakdown of tokens used in a completion.
type CompletionTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
6 changes: 4 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type APIError struct {
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
HTTPStatus string `json:"-"`
HTTPStatusCode int `json:"-"`
InnerError *InnerError `json:"innererror,omitempty"`
}
Expand All @@ -25,6 +26,7 @@ type InnerError struct {

// RequestError provides information about generic request errors.
type RequestError struct {
HTTPStatus string
HTTPStatusCode int
Err error
}
Expand All @@ -35,7 +37,7 @@ type ErrorResponse struct {

func (e *APIError) Error() string {
if e.HTTPStatusCode > 0 {
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message)
}

return e.Message
Expand Down Expand Up @@ -101,7 +103,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
}

func (e *RequestError) Error() string {
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err)
}

func (e *RequestError) Unwrap() error {
Expand Down
1 change: 1 addition & 0 deletions files_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, wantErrorResp)
})
Expand Down
4 changes: 2 additions & 2 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ type Run struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
// Disable the default behavior of parallel tool calls by setting it: false.
ParallelToolCalls any `json:"parallel_tool_calls,omitempty"`

httpHeader
}
Expand Down Expand Up @@ -112,6 +110,8 @@ type RunRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
// This can be either a string or a ResponseFormat object.
ResponseFormat any `json:"response_format,omitempty"`
// Disable the default behavior of parallel tool calls by setting it: false.
ParallelToolCalls any `json:"parallel_tool_calls,omitempty"`
}

// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
Expand Down

0 comments on commit 11af77b

Please sign in to comment.