Skip to content

Commit

Permalink
Merge branch 'sashabaranov:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
eiixy authored Aug 13, 2024
2 parents 5f14c63 + 2c6889e commit ea1609c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 18 deletions.
61 changes: 61 additions & 0 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package openai_test

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -246,3 +247,63 @@ func TestAPIError(t *testing.T) {
t.Fatal("Empty error message occurred")
}
}

func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := openai.NewClient(apiToken)
ctx := context.Background()

resp, err := c.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4oMini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
"2. CamelCase: The first word starts with a lowercase letter, " +
"and subsequent words start with an uppercase letter, with no spaces or separators." +
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
},
{
Role: openai.ChatMessageRoleUser,
Content: "Hello World",
},
},
ResponseFormat: &openai.ChatCompletionResponseFormat{
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
Name: "cases",
Schema: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"PascalCase": jsonschema.Definition{Type: jsonschema.String},
"CamelCase": jsonschema.Definition{Type: jsonschema.String},
"KebabCase": jsonschema.Definition{Type: jsonschema.String},
"SnakeCase": jsonschema.Definition{Type: jsonschema.String},
},
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
AdditionalProperties: false,
},
Strict: true,
},
},
},
)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
var result = make(map[string]string)
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
if _, ok := result[key]; !ok {
t.Errorf("key:%s does not exist.", key)
}
}
}
13 changes: 12 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"net/http"

"github.com/sashabaranov/go-openai/jsonschema"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -175,11 +177,20 @@ type ChatCompletionResponseFormatType string

const (
ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object"
ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema"
ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text"
)

type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"`
}

type ChatCompletionResponseFormatJSONSchema struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Schema jsonschema.Definition `json:"schema"`
Strict bool `json:"strict"`
}

// ChatCompletionRequest represents a request structure for chat completion API.
Expand Down
35 changes: 19 additions & 16 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
GPT40314 = "gpt-4-0314"
GPT4o = "gpt-4o"
GPT4o20240513 = "gpt-4o-2024-05-13"
GPT4o20240806 = "gpt-4o-2024-08-06"
GPT4oMini = "gpt-4o-mini"
GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
GPT4Turbo = "gpt-4-turbo"
Expand Down Expand Up @@ -91,6 +92,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT4: true,
GPT4o: true,
GPT4o20240513: true,
GPT4o20240806: true,
GPT4oMini: true,
GPT4oMini20240718: true,
GPT4TurboPreview: true,
Expand Down Expand Up @@ -136,25 +138,26 @@ func checkPromptType(prompt any) bool {

// CompletionRequest represents a request structure for completion API.
type CompletionRequest struct {
Model string `json:"model"`
Prompt any `json:"prompt,omitempty"`
Suffix string `json:"suffix,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
Echo bool `json:"echo,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
BestOf int `json:"best_of,omitempty"`
Model string `json:"model"`
Prompt any `json:"prompt,omitempty"`
BestOf int `json:"best_of,omitempty"`
Echo bool `json:"echo,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
// LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
// refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
Seed *int `json:"seed,omitempty"`
Stop []string `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
Suffix string `json:"suffix,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
User string `json:"user,omitempty"`
}

// CompletionChoice represents one of possible completions.
Expand Down
8 changes: 7 additions & 1 deletion jsonschema/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ type Definition struct {
// one element, where each element is unique. You will probably only use this with strings.
Enum []string `json:"enum,omitempty"`
// Properties describes the properties of an object, if the schema type is Object.
Properties map[string]Definition `json:"properties"`
Properties map[string]Definition `json:"properties,omitempty"`
// Required specifies which properties are required, if the schema type is Object.
Required []string `json:"required,omitempty"`
// Items specifies which data type an array contains, if the schema type is Array.
Items *Definition `json:"items,omitempty"`
// AdditionalProperties is used to control the handling of properties in an object
// that are not explicitly defined in the properties section of the schema. example:
// additionalProperties: true
// additionalProperties: false
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
AdditionalProperties any `json:"additionalProperties,omitempty"`
}

func (d Definition) MarshalJSON() ([]byte, error) {
Expand Down

0 comments on commit ea1609c

Please sign in to comment.