add max_completions_tokens for o1 series models #857

Merged: 2 commits, Sep 20, 2024
35 changes: 23 additions & 12 deletions chat.go
@@ -200,18 +200,25 @@ type ChatCompletionResponseFormatJSONSchema struct {

// ChatCompletionRequest represents a request structure for chat completion API.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
// MaxTokens is the maximum number of tokens that can be generated in the chat completion.
// This value can be used to control costs for text generated via the API.
// It is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models.
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
MaxTokens int `json:"max_tokens,omitempty"`
// MaxCompletionsTokens is an upper bound on the number of tokens that can be generated for a completion,
// including visible output tokens and reasoning tokens.
// refs: https://platform.openai.com/docs/guides/reasoning
MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
Seed *int `json:"seed,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
// LogitBias keys must be token IDs (as assigned by the tokenizer) given as strings, not words.
// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
@@ -364,6 +371,10 @@ func (c *Client) CreateChatCompletion(
return
}

if err = validateRequestForO1Models(request); err != nil {
return
}

req, err := c.newRequest(
ctx,
http.MethodPost,
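Note: both new call sites (here and in chat_stream.go below) delegate to validateRequestForO1Models, whose definition is outside this diff. The following is a minimal sketch of what such a validator could look like, reconstructed from the error values the new tests assert against; the error messages and the exact order of checks are assumptions, not the PR's actual code:

package openai

import "errors"

// These error values are referenced by the tests in this PR; the messages
// below are placeholders, since the real declarations are outside this diff.
var (
	ErrO1MaxTokensDeprecated         = errors.New("o1 models: MaxTokens is deprecated, use MaxCompletionsTokens")
	ErrO1BetaLimitationsStreaming    = errors.New("o1 beta limitation: streaming is not supported")
	ErrO1BetaLimitationsLogprobs     = errors.New("o1 beta limitation: logprobs are not supported")
	ErrO1BetaLimitationsMessageTypes = errors.New("o1 beta limitation: only user and assistant messages are supported")
	ErrO1BetaLimitationsTools        = errors.New("o1 beta limitation: tools are not supported")
	ErrO1BetaLimitationsOther        = errors.New("o1 beta limitation: temperature, top_p and n are fixed at 1; penalties must be unset")
)

// validateRequestForO1Models rejects request options that the o1 beta does
// not accept. The checks are inferred from the test cases in chat_test.go
// and chat_stream_test.go.
func validateRequestForO1Models(request ChatCompletionRequest) error {
	if request.Model != O1Mini && request.Model != O1Preview {
		return nil // non-o1 models are unaffected
	}
	if request.MaxTokens > 0 {
		return ErrO1MaxTokensDeprecated // callers must switch to MaxCompletionsTokens
	}
	if request.Stream {
		return ErrO1BetaLimitationsStreaming
	}
	if request.LogProbs {
		return ErrO1BetaLimitationsLogprobs
	}
	for _, m := range request.Messages {
		if m.Role != ChatMessageRoleUser && m.Role != ChatMessageRoleAssistant {
			return ErrO1BetaLimitationsMessageTypes
		}
	}
	if len(request.Tools) > 0 {
		return ErrO1BetaLimitationsTools
	}
	// During the beta, temperature, top_p and n are fixed at 1 and the
	// penalty parameters at 0; the zero value is treated as "unset".
	if (request.Temperature != 0 && request.Temperature != 1) ||
		(request.TopP != 0 && request.TopP != 1) ||
		(request.N != 0 && request.N != 1) ||
		request.PresencePenalty != 0 ||
		request.FrequencyPenalty != 0 {
		return ErrO1BetaLimitationsOther
	}
	return nil
}

Treating zero values as valid matches the passing TestO1ModelChatCompletions case below, which sets neither temperature nor penalties.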
4 changes: 4 additions & 0 deletions chat_stream.go
@@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
if err = validateRequestForO1Models(request); err != nil {
return
}

req, err := c.newRequest(
ctx,
http.MethodPost,
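From the caller's perspective, this hunk means an o1 streaming request now fails locally, before any HTTP round trip. A usage sketch, assuming the module path github.com/sashabaranov/go-openai and a placeholder API key:

package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key") // placeholder token
	req := openai.ChatCompletionRequest{
		Model:                openai.O1Preview,
		MaxCompletionsTokens: 1000,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	}

	// Streaming is rejected for o1 models during the beta...
	_, err := client.CreateChatCompletionStream(context.Background(), req)
	if errors.Is(err, openai.ErrO1BetaLimitationsStreaming) {
		// ...so fall back to the blocking endpoint.
		resp, err := client.CreateChatCompletion(context.Background(), req)
		if err != nil {
			fmt.Println("request failed:", err)
			return
		}
		fmt.Println(resp.Choices[0].Message.Content)
	}
}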
21 changes: 21 additions & 0 deletions chat_stream_test.go
@@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
}
}

func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1/chat/completions"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

req := openai.ChatCompletionRequest{
Model: openai.O1Preview,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
}
_, err := client.CreateChatCompletionStream(ctx, req)
if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) {
t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err)
}
}

func TestCreateChatCompletionStream(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
211 changes: 211 additions & 0 deletions chat_test.go
@@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) {
checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
}

func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "o1-preview_MaxTokens_deprecated",
in: openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1MaxTokensDeprecated,
},
{
name: "o1-mini_MaxTokens_deprecated",
in: openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.O1Mini,
},
expectedError: openai.ErrO1MaxTokensDeprecated,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return a deprecated MaxTokens error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}

func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "log_probs_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
LogProbs: true,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1BetaLimitationsLogprobs,
},
{
name: "message_type_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
},
},
},
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
},
{
name: "tool_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
},
},
},
expectedError: openai.ErrO1BetaLimitationsTools,
},
{
name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(2),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_top_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_n_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return a beta limitation error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}

func TestChatRequestOmitEmpty(t *testing.T) {
data, err := json.Marshal(openai.ChatCompletionRequest{
// We set model b/c it's required, so omitempty doesn't make sense
@@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestO1ModelChatCompletions tests the chat completions endpoint with an o1 series model using the mocked server.
func TestO1ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: openai.O1Preview,
MaxCompletionsTokens: 1000,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestChatCompletionsWithHeaders tests the chat completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
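Taken together, the migration for callers targeting o1 models is to move the token cap from MaxTokens to MaxCompletionsTokens. A minimal before/after sketch (field values are illustrative):

// Before: rejected with ErrO1MaxTokensDeprecated for o1 models.
req := openai.ChatCompletionRequest{
	Model:     openai.O1Mini,
	MaxTokens: 1000,
}

// After: MaxCompletionsTokens bounds visible output plus hidden reasoning
// tokens, so the same budget buys less visible output than MaxTokens did.
req = openai.ChatCompletionRequest{
	Model:                openai.O1Mini,
	MaxCompletionsTokens: 1000,
}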