Commit
add validation for o1 series models + beta limitations
chococola committed Sep 20, 2024
1 parent 0ffcd83 commit 4d313aa
Showing 5 changed files with 322 additions and 0 deletions.
4 changes: 4 additions & 0 deletions chat.go
@@ -371,6 +371,10 @@ func (c *Client) CreateChatCompletion(
return
}

if err = validateRequestForO1Models(request); err != nil {
return
}

req, err := c.newRequest(
ctx,
http.MethodPost,
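For context, a minimal sketch (editor's addition, not part of the commit) of how a caller might handle the new client-side validation. It assumes the go-openai import path and uses only identifiers visible in this diff; the API token is a placeholder.

package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token")
	req := openai.ChatCompletionRequest{
		Model:     openai.O1Preview,
		MaxTokens: 100, // deprecated for o1 series models
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	}
	_, err := client.CreateChatCompletion(context.Background(), req)
	if errors.Is(err, openai.ErrO1MaxTokensDeprecated) {
		// Retry with the field the o1 series expects.
		req.MaxTokens = 0
		req.MaxCompletionsTokens = 100
		_, err = client.CreateChatCompletion(context.Background(), req)
	}
	fmt.Println(err)
}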
4 changes: 4 additions & 0 deletions chat_stream.go
@@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
if err = validateRequestForO1Models(request); err != nil {
return
}

req, err := c.newRequest(
ctx,
http.MethodPost,
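Likewise, a sketch (editor's addition, same assumptions as above) of falling back to a blocking call when streaming is rejected for an o1 model. Because CreateChatCompletionStream receives the request by value, the caller's Stream field stays false and the same request can be reused for the fallback.

package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token")
	req := openai.ChatCompletionRequest{
		Model:                openai.O1Preview,
		MaxCompletionsTokens: 1000,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	}
	// o1 models reject streaming during the beta, so fall back to a
	// non-streaming call when the client reports the limitation.
	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if errors.Is(err, openai.ErrO1BetaLimitationsStreaming) {
		resp, rerr := client.CreateChatCompletion(context.Background(), req)
		if rerr != nil {
			fmt.Println(rerr)
			return
		}
		fmt.Println(resp.Choices[0].Message.Content)
		return
	}
	if err != nil {
		fmt.Println(err)
		return
	}
	defer stream.Close()
	// stream.Recv() loop would go here.
}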
21 changes: 21 additions & 0 deletions chat_stream_test.go
@@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
}
}

func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1/chat/completions"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

req := openai.ChatCompletionRequest{
Model: openai.O1Preview,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
}
_, err := client.CreateChatCompletionStream(ctx, req)
if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) {
t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err)
}
}

func TestCreateChatCompletionStream(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
211 changes: 211 additions & 0 deletions chat_test.go
@@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) {
checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
}

func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "o1-preview_MaxTokens_deprecated",
in: openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1MaxTokensDeprecated,
},
{
name: "o1-mini_MaxTokens_deprecated",
in: openai.ChatCompletionRequest{
MaxTokens: 5,
Model: openai.O1Mini,
},
expectedError: openai.ErrO1MaxTokensDeprecated,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return deprecated MaxTokens error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}

func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "log_probs_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
LogProbs: true,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1BetaLimitationsLogprobs,
},
{
name: "message_type_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
},
},
},
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
},
{
name: "tool_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
},
},
},
expectedError: openai.ErrO1BetaLimitationsTools,
},
{
name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(2),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_top_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_n_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionsTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()

_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return beta limitations error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}

func TestChatRequestOmitEmpty(t *testing.T) {
data, err := json.Marshal(openai.ChatCompletionRequest{
// We set model b/c it's required, so omitempty doesn't make sense
@@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestO1ModelChatCompletions Tests the chat completions endpoint with an o1 series model using the mocked server.
func TestO1ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: openai.O1Preview,
MaxCompletionsTokens: 1000,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}

// TestChatCompletionsWithHeaders Tests the chat completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
82 changes: 82 additions & 0 deletions completion.go
@@ -7,11 +7,20 @@ import (
)

var (
ErrO1MaxTokensDeprecated                   = errors.New("this model does not support MaxTokens, please use MaxCompletionsTokens") //nolint:lll
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
)

var (
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta limitations, user and assistant messages only, system messages are not supported") //nolint:lll
ErrO1BetaLimitationsStreaming    = errors.New("this model has beta limitations, streaming not supported") //nolint:lll
ErrO1BetaLimitationsTools        = errors.New("this model has beta limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
ErrO1BetaLimitationsLogprobs     = errors.New("this model has beta limitations, logprobs not supported") //nolint:lll
ErrO1BetaLimitationsOther        = errors.New("this model has beta limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
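// A hypothetical convenience helper (editor's sketch, not part of this
// commit) showing how callers can map these sentinel errors to corrective
// guidance; it relies only on errors.Is and the variables defined above.
func describeO1Limitation(err error) string {
	switch {
	case errors.Is(err, ErrO1MaxTokensDeprecated):
		return "set MaxCompletionsTokens instead of MaxTokens"
	case errors.Is(err, ErrO1BetaLimitationsStreaming):
		return "disable streaming for o1 series requests"
	case errors.Is(err, ErrO1BetaLimitationsMessageTypes):
		return "use only user and assistant messages"
	case errors.Is(err, ErrO1BetaLimitationsTools):
		return "remove tools and function-calling parameters"
	case errors.Is(err, ErrO1BetaLimitationsLogprobs):
		return "disable logprobs"
	case errors.Is(err, ErrO1BetaLimitationsOther):
		return "leave temperature, top_p, n, and the penalties unset"
	default:
		return ""
	}
}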

// GPT3 Defines the models provided by OpenAI to use when generating
// completions from OpenAI.
// GPT3 Models are designed for text-based tasks. For code-specific
@@ -85,6 +94,15 @@ const (
CodexCodeDavinci001 = "code-davinci-001"
)

// O1SeriesModels lists the o1 series of OpenAI models.
// Some legacy API attributes are not supported.
var O1SeriesModels = map[string]struct{}{
O1Mini: {},
O1Mini20240912: {},
O1Preview: {},
O1Preview20240912: {},
}

var disabledModelsForEndpoints = map[string]map[string]bool{
"/completions": {
O1Mini: true,
@@ -146,6 +164,70 @@ func checkPromptType(prompt any) bool {
return isString || isStringSlice
}

var unsupportedToolsForO1Models = map[ToolType]struct{}{
ToolTypeFunction: {},
}

var availableMessageRoleForO1Models = map[string]struct{}{
ChatMessageRoleUser: {},
ChatMessageRoleAssistant: {},
}

// validateRequestForO1Models checks for fields that are deprecated or unsupported by o1 series models.
func validateRequestForO1Models(request ChatCompletionRequest) error {
if _, found := O1SeriesModels[request.Model]; !found {
return nil
}

if request.MaxTokens > 0 {
return ErrO1MaxTokensDeprecated
}

// Beta limitations
// refs: https://platform.openai.com/docs/guides/reasoning/beta-limitations
// Streaming: not supported.
if request.Stream {
return ErrO1BetaLimitationsStreaming
}
// Logprobs: not supported.
if request.LogProbs {
return ErrO1BetaLimitationsLogprobs
}

// Message types: user and assistant messages only, system messages are not supported.
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
}
}

// Tools: tools, function calling, and response format parameters are not supported.
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
}
}

// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
if request.Temperature > 0 && request.Temperature != 1 {
return ErrO1BetaLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
return ErrO1BetaLimitationsOther
}
if request.N > 0 && request.N != 1 {
return ErrO1BetaLimitationsOther
}
if request.PresencePenalty > 0 {
return ErrO1BetaLimitationsOther
}
if request.FrequencyPenalty > 0 {
return ErrO1BetaLimitationsOther
}

return nil
}
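// Note on the zero-value checks above (editor's sketch, not part of this
// commit): Go cannot distinguish an unset field from an explicit zero, so
// the validator treats 0 as "unset" for Temperature, TopP, and N, and only
// rejects values that are set and differ from the fixed defaults. For
// example:
//
//	req := ChatCompletionRequest{Model: O1Mini, MaxCompletionsTokens: 100}
//	_ = validateRequestForO1Models(req) // nil: numeric knobs left at zero value
//
//	req.Temperature = 0.5
//	_ = validateRequestForO1Models(req) // ErrO1BetaLimitationsOther
//
// (Negative values slip past the > 0 guards.)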

// CompletionRequest represents a request structure for completion API.
type CompletionRequest struct {
Model string `json:"model"`
