Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into sync-master-1009
Browse files Browse the repository at this point in the history
  • Loading branch information
coolbaluk committed Sep 10, 2024
2 parents 8b29269 + 643da8d commit 75363a4
Show file tree
Hide file tree
Showing 26 changed files with 826 additions and 88 deletions.
68 changes: 66 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func main() {
ctx := context.Background()

req := openai.CompletionRequest{
Model: openai.GPT3Ada,
Model: openai.GPT3Babbage002,
MaxTokens: 5,
Prompt: "Lorem ipsum",
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func main() {
ctx := context.Background()

req := openai.CompletionRequest{
Model: openai.GPT3Ada,
Model: openai.GPT3Babbage002,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
Expand Down Expand Up @@ -743,6 +743,70 @@ func main() {
}
```
</details>

<details>
<summary>Structured Outputs</summary>

```go
package main

import (
"context"
"fmt"
"log"

"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)

func main() {
client := openai.NewClient("your token")
ctx := context.Background()

type Result struct {
Steps []struct {
Explanation string `json:"explanation"`
Output string `json:"output"`
} `json:"steps"`
FinalAnswer string `json:"final_answer"`
}
var result Result
schema, err := jsonschema.GenerateSchemaForType(result)
if err != nil {
log.Fatalf("GenerateSchemaForType error: %v", err)
}
resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "You are a helpful math tutor. Guide the user through the solution step by step.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "how can I solve 8x + 7 = -23",
},
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "math_reasoning",
Schema: schema,
Strict: true,
},
},
})
if err != nil {
log.Fatalf("CreateChatCompletion error: %v", err)
}
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
if err != nil {
log.Fatalf("Unmarshal schema error: %v", err)
}
fmt.Println(result)
}
```
</details>
See the `examples/` folder for more.

## Frequently Asked Questions
Expand Down
101 changes: 87 additions & 14 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
c := openai.NewClient(apiToken)
ctx := context.Background()

type MyStructuredResponse struct {
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
}
var result MyStructuredResponse
schema, err := jsonschema.GenerateSchemaForType(result)
if err != nil {
t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error")
}
resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Expand All @@ -212,27 +223,89 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": jsonschema.Definition{Type: jsonschema.String},
"CamelCase": jsonschema.Definition{Type: jsonschema.String},
"KebabCase": jsonschema.Definition{Type: jsonschema.String},
"SnakeCase": jsonschema.Definition{Type: jsonschema.String},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
Name: "cases",
Schema: schema,
Strict: true,
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
if err == nil {
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
}
}

func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := openai.NewClient(apiToken)
ctx := context.Background()

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
"2. CamelCase: The first word starts with a lowercase letter, " +
"and subsequent words start with an uppercase letter, with no spaces or separators." +
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello World",
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
Function: &openai.FunctionDefinition{
Name: "display_cases",
Strict: true,
Parameters: &jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": {
Type: jsonschema.String,
},
"CamelCase": {
Type: jsonschema.String,
},
"KebabCase": {
Type: jsonschema.String,
},
"SnakeCase": {
Type: jsonschema.String,
},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
},
},
},
ToolChoice: openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: "display_cases",
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result)
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
Expand Down
24 changes: 19 additions & 5 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ func TestAzureFullURL(t *testing.T) {
Name string
BaseURL string
AzureModelMapper map[string]string
Suffix string
Model string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
Expand All @@ -128,19 +130,28 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-05-15",
},
{
"",
"https://httpbin.org",
nil,
"/assistants?limit=10",
"chatgpt-demo",
"https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions", c.Model)
actual := cli.fullURL(c.Suffix, withModel(c.Model))
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand All @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Suffix string
Expect string
}{
{
"CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
"/chat/completions",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"CloudflareAzureBaseURLWithoutSlashOK",
"",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
"/assistants?limit=10",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" +
"/assistants?api-version=2023-05-15&limit=10",
},
}

Expand All @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {

cli := NewClientWithConfig(az)

actual := cli.fullURL("/chat/completions")
actual := cli.fullURL(c.Suffix)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand Down
9 changes: 7 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
withBody(&formBody), withContentType(builder.FormDataContentType()))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(&formBody),
withContentType(builder.FormDataContentType()),
)
if err != nil {
return AudioResponse{}, err
}
Expand Down
18 changes: 11 additions & 7 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct {
}

type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema jsonschema.Definition `json:"schema"`
Strict bool `json:"strict"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema json.Marshaler `json:"schema"`
Strict bool `json:"strict"`
}

// ChatCompletionRequest represents a request structure for chat completion API.
Expand Down Expand Up @@ -264,6 +262,7 @@ type ToolFunction struct {
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
// Parameters is an object describing the function.
// You can pass json.RawMessage to describe the schema,
// or you can pass in a struct which serializes to the proper JSON schema.
Expand Down Expand Up @@ -358,7 +357,12 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
Expand Down
7 changes: 6 additions & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 75363a4

Please sign in to comment.