From c9c258b74343487ef89e11cb206bf950218850ee Mon Sep 17 00:00:00 2001 From: WqyJh <781345688@qq.com> Date: Mon, 11 Nov 2024 19:42:35 +0800 Subject: [PATCH] feat: support gpt-4o-audio-preview --- .gitignore | 1 + chat.go | 143 ++++++++++++++++-------- chat_stream.go | 17 ++- chat_test.go | 191 +++++++++++++++++++++++++++++++++ common.go | 9 +- completion.go | 62 +++++------ internal/test/checks/checks.go | 103 ++++++++++++++++++ 7 files changed, 444 insertions(+), 82 deletions(-) diff --git a/.gitignore b/.gitignore index 99b40bf17..4eba3ba84 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ # Test binary, built with `go test -c` *.test +test.mp3 # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/chat.go b/chat.go index 2b13f8dd7..d2f9a8e05 100644 --- a/chat.go +++ b/chat.go @@ -78,17 +78,63 @@ type ChatMessageImageURL struct { Detail ImageURLDetail `json:"detail,omitempty"` } +type AudioVoice string + +const ( + AudioVoiceAlloy AudioVoice = "alloy" + AudioVoiceAsh AudioVoice = "ash" + AudioVoiceBallad AudioVoice = "ballad" + AudioVoiceCoral AudioVoice = "coral" + AudioVoiceEcho AudioVoice = "echo" + AudioVoiceSage AudioVoice = "sage" + AudioVoiceShimmer AudioVoice = "shimmer" + AudioVoiceVerse AudioVoice = "verse" +) + +type AudioFormat string + +const ( + AudioFormatWAV AudioFormat = "wav" + AudioFormatMP3 AudioFormat = "mp3" + AudioFormatFLAC AudioFormat = "flac" + AudioFormatOPUS AudioFormat = "opus" + AudioFormatPCM16 AudioFormat = "pcm16" +) + +type ChatMessageAudio struct { + // Base64 encoded audio data. + Data string `json:"data,omitempty"` + // The format of the encoded audio data. Currently supports "wav" and "mp3". + Format AudioFormat `json:"format,omitempty"` +} + +type Modality string + +const ( + ModalityAudio Modality = "audio" + ModalityText Modality = "text" +) + +type AudioOutput struct { + // The voice the model uses to respond. Supported voices are alloy, ash, ballad, coral, echo, sage, shimmer, and verse. + Voice AudioVoice `json:"voice"` + // Specifies the output audio format. Must be one of wav, mp3, flac, opus, or pcm16. + Format AudioFormat `json:"format"` +} + type ChatMessagePartType string const ( - ChatMessagePartTypeText ChatMessagePartType = "text" - ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeInputAudio ChatMessagePartType = "input_audio" ) type ChatMessagePart struct { - Type ChatMessagePartType `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + InputAudio *ChatMessageAudio `json:"input_audio,omitempty"` } type ChatCompletionMessage struct { @@ -110,6 +156,33 @@ type ChatCompletionMessage struct { // For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. ToolCallID string `json:"tool_call_id,omitempty"` + + // If the audio output modality is requested, this object contains data about the audio response from the model. + Audio *ChatCompletionAudio `json:"audio,omitempty"` +} + +type chatCompletionMessageMultiContent struct { + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Audio *ChatCompletionAudio `json:"audio,omitempty"` +} + +type chatCompletionMessageSingleContent struct { + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Audio *ChatCompletionAudio `json:"audio,omitempty"` } func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { @@ -117,58 +190,22 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { return nil, ErrContentFieldsMisused } if len(m.MultiContent) > 0 { - msg := struct { - Role string `json:"role"` - Content string `json:"-"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - }(m) + msg := chatCompletionMessageMultiContent(m) return json.Marshal(msg) } - msg := struct { - Role string `json:"role"` - Content string `json:"content"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - }(m) + msg := chatCompletionMessageSingleContent(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { - msg := struct { - Role string `json:"role"` - Content string `json:"content"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - }{} + msg := chatCompletionMessageSingleContent{} if err := json.Unmarshal(bs, &msg); err == nil { *m = ChatCompletionMessage(msg) return nil } - multiMsg := struct { - Role string `json:"role"` - Content string - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - }{} + multiMsg := chatCompletionMessageMultiContent{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err } @@ -176,6 +213,17 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { return nil } +type ChatCompletionAudio struct { + // Unique identifier for this audio response. + ID string `json:"id"` + // The Unix timestamp (in seconds) for when this audio response will no longer be accessible on the server for use in multi-turn conversations. + ExpiresAt int64 `json:"expires_at"` + // Base64 encoded audio bytes generated by the model, in the format specified in the request. + Data string `json:"data"` + // Transcript of the audio generated by the model. + Transcript string `json:"transcript"` +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` @@ -260,6 +308,11 @@ type ChatCompletionRequest struct { Store bool `json:"store,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` + // Output types that you would like the model to generate for this request. Most models are capable of generating text, which is the default: ["text"] + // The gpt-4o-audio-preview model can also be used to generate audio. To request that this model generate both text and audio responses, you can use: ["text", "audio"] + Modalities []Modality `json:"modalities,omitempty"` + // Parameters for audio output. Required when audio output is requested with modalities: ["audio"] + Audio *AudioOutput `json:"audio,omitempty"` } type StreamOptions struct { diff --git a/chat_stream.go b/chat_stream.go index 58b2651c0..f2e6fb34a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -5,12 +5,19 @@ import ( "net/http" ) +type ChatCompletionStreamChoiceDeltaAudio struct { + ID string `json:"id,omitempty"` + Transcript string `json:"transcript,omitempty"` + Data string `json:"data,omitempty"` +} + type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - Refusal string `json:"refusal,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` + Audio *ChatCompletionStreamChoiceDeltaAudio `json:"audio,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { diff --git a/chat_test.go b/chat_test.go index 134026cdb..0d11ba2a4 100644 --- a/chat_test.go +++ b/chat_test.go @@ -764,3 +764,194 @@ func TestFinishReason(t *testing.T) { } } } + +func TestChatCompletionRequestAudio(t *testing.T) { + cases := []struct { + request openai.ChatCompletionRequest + want string + }{ + { + request: openai.ChatCompletionRequest{ + Model: openai.GPT4oAudioPreview, + Modalities: []openai.Modality{openai.ModalityText, openai.ModalityAudio}, + Audio: &openai.AudioOutput{ + Voice: "alloy", + Format: "pcm16", + }, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Is a golden retriever a good family dog?", + }, + }, + }, + want: `{"model":"gpt-4o-audio-preview","modalities":["text","audio"],"audio":{"voice":"alloy","format":"pcm16"},"messages":[{"role":"user","content":"Is a golden retriever a good family dog?"}]}`, + }, + } + + for _, c := range cases { + resBytes, _ := json.Marshal(c.request) + checks.JSONEq(t, c.want, string(resBytes)) + + var expected openai.ChatCompletionRequest + err := json.Unmarshal([]byte(c.want), &expected) + checks.NoError(t, err) + checks.Equal(t, c.request, expected) + } +} + +func TestChatCompletionResponseAudio(t *testing.T) { + cases := []struct { + response openai.ChatCompletionResponse + want string + }{ + { + response: openai.ChatCompletionResponse{ + ID: "chatcmpl-ASKCthZk3MUOqqRh64CbUbeTmZ6xl", + Object: "chat.completion", + Created: 1731314223, + Model: openai.GPT4oAudioPreview20241001, + Choices: []openai.ChatCompletionChoice{ + { + Index: 0, + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Audio: &openai.ChatCompletionAudio{ + ID: "audio_6731c23369048190aee358c51e0373d2", + Data: "base64 encoded data", + ExpiresAt: 1731317827, + Transcript: "Yes, golden retrievers are known to be excellent family dogs. They are friendly, gentle, and great with children. Golden retrievers are also intelligent and eager to please, making them easy to train. They tend to get along well with other pets and are known for their loyalty and protective nature.", + }, + }, + FinishReason: openai.FinishReasonStop, + }, + }, + Usage: openai.Usage{ + PromptTokens: 17, + CompletionTokens: 483, + TotalTokens: 500, + PromptTokensDetails: &openai.PromptTokensDetails{ + CachedTokens: 0, + AudioTokens: 0, + TextTokens: 17, + ImageTokens: 0, + }, + CompletionTokensDetails: &openai.CompletionTokensDetails{ + ReasoningTokens: 0, + AudioTokens: 398, + TextTokens: 85, + AcceptedPredictionTokens: 0, + RejectedPredictionTokens: 0, + }, + }, + SystemFingerprint: "fp_49254d0e9b", + }, + want: `{"id":"chatcmpl-ASKCthZk3MUOqqRh64CbUbeTmZ6xl","object":"chat.completion","created":1731314223,"model":"gpt-4o-audio-preview-2024-10-01","choices":[{"index":0,"message":{"role":"assistant","content":null,"refusal":null,"audio":{"id":"audio_6731c23369048190aee358c51e0373d2","data":"base64 encoded data","expires_at":1731317827,"transcript":"Yes, golden retrievers are known to be excellent family dogs. They are friendly, gentle, and great with children. Golden retrievers are also intelligent and eager to please, making them easy to train. They tend to get along well with other pets and are known for their loyalty and protective nature."}},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":483,"total_tokens":500,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0,"text_tokens":17,"image_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":398,"text_tokens":85,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"system_fingerprint":"fp_49254d0e9b"}`, + }, + } + + for _, c := range cases { + var expected openai.ChatCompletionResponse + err := json.Unmarshal([]byte(c.want), &expected) + checks.NoError(t, err) + checks.Equal(t, c.response, expected) + } +} + +func TestChatCompletionStreamResponseAudio(t *testing.T) { + cases := []struct { + response openai.ChatCompletionStreamResponse + want string + }{ + { + response: openai.ChatCompletionStreamResponse{ + ID: "chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p", + Object: "chat.completion.chunk", + Created: 1731313962, + Model: openai.GPT4oAudioPreview20241001, + SystemFingerprint: "fp_49254d0e9b", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Audio: &openai.ChatCompletionStreamChoiceDeltaAudio{ + ID: "audio_6731c12b1c148190bb8db8af1330221a", + Transcript: "Yes", + }, + }, + }, + }, + }, + want: `{"id":"chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p","object":"chat.completion.chunk","created":1731313962,"model":"gpt-4o-audio-preview-2024-10-01","system_fingerprint":"fp_49254d0e9b","choices":[{"index":0,"delta":{"content":null,"audio":{"id":"audio_6731c12b1c148190bb8db8af1330221a","transcript":"Yes"}},"finish_reason":null}]}`, + }, + { + response: openai.ChatCompletionStreamResponse{ + ID: "chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p", + Object: "chat.completion.chunk", + Created: 1731313962, + Model: openai.GPT4oAudioPreview20241001, + SystemFingerprint: "fp_49254d0e9b", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Audio: &openai.ChatCompletionStreamChoiceDeltaAudio{ + Transcript: ",", + }, + }, + }, + }, + }, + want: `{"id":"chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p","object":"chat.completion.chunk","created":1731313962,"model":"gpt-4o-audio-preview-2024-10-01","system_fingerprint":"fp_49254d0e9b","choices":[{"index":0,"delta":{"audio":{"transcript":","}},"finish_reason":null}]}`, + }, + { + response: openai.ChatCompletionStreamResponse{ + ID: "chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p", + Object: "chat.completion.chunk", + Created: 1731313962, + Model: openai.GPT4oAudioPreview20241001, + SystemFingerprint: "fp_49254d0e9b", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: openai.ChatMessageRoleAssistant, + Audio: &openai.ChatCompletionStreamChoiceDeltaAudio{ + ID: "audio_6731c12b1c148190bb8db8af1330221a", + Data: "base64 encoded data", + }, + }, + }, + }, + }, + want: `{"id":"chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p","object":"chat.completion.chunk","created":1731313962,"model":"gpt-4o-audio-preview-2024-10-01","system_fingerprint":"fp_49254d0e9b","choices":[{"index":0,"delta":{"role":"assistant","content":null,"refusal":null,"audio":{"id":"audio_6731c12b1c148190bb8db8af1330221a","data":"base64 encoded data"}},"finish_reason":null}]}`, + }, + { + response: openai.ChatCompletionStreamResponse{ + ID: "chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p", + Object: "chat.completion.chunk", + Created: 1731313962, + Model: openai.GPT4oAudioPreview20241001, + SystemFingerprint: "fp_49254d0e9b", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Audio: &openai.ChatCompletionStreamChoiceDeltaAudio{ + Data: "base64 encoded data", + }, + }, + }, + }, + }, + want: `{"id":"chatcmpl-ASK8gd4isaVFw7qClLmtrcwWvka7p","object":"chat.completion.chunk","created":1731313962,"model":"gpt-4o-audio-preview-2024-10-01","system_fingerprint":"fp_49254d0e9b","choices":[{"index":0,"delta":{"audio":{"data":"base64 encoded data"}},"finish_reason":null}]}`, + }, + } + + for _, c := range cases { + var expected openai.ChatCompletionStreamResponse + err := json.Unmarshal([]byte(c.want), &expected) + checks.NoError(t, err) + checks.Equal(t, c.response, expected) + } +} diff --git a/common.go b/common.go index 8cc7289c0..5923b0ddd 100644 --- a/common.go +++ b/common.go @@ -13,12 +13,17 @@ type Usage struct { // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { - AudioTokens int `json:"audio_tokens"` - ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + TextTokens int `json:"text_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` } // PromptTokensDetails Breakdown of tokens used in the prompt. type PromptTokensDetails struct { AudioTokens int `json:"audio_tokens"` CachedTokens int `json:"cached_tokens"` + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` } diff --git a/completion.go b/completion.go index 77ea8c3ab..d74c2b43e 100644 --- a/completion.go +++ b/completion.go @@ -26,36 +26,38 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( - O1Mini = "o1-mini" - O1Mini20240912 = "o1-mini-2024-09-12" - O1Preview = "o1-preview" - O1Preview20240912 = "o1-preview-2024-09-12" - GPT432K0613 = "gpt-4-32k-0613" - GPT432K0314 = "gpt-4-32k-0314" - GPT432K = "gpt-4-32k" - GPT40613 = "gpt-4-0613" - GPT40314 = "gpt-4-0314" - GPT4o = "gpt-4o" - GPT4o20240513 = "gpt-4o-2024-05-13" - GPT4o20240806 = "gpt-4o-2024-08-06" - GPT4oLatest = "chatgpt-4o-latest" - GPT4oMini = "gpt-4o-mini" - GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" - GPT4Turbo = "gpt-4-turbo" - GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" - GPT4Turbo0125 = "gpt-4-0125-preview" - GPT4Turbo1106 = "gpt-4-1106-preview" - GPT4TurboPreview = "gpt-4-turbo-preview" - GPT4VisionPreview = "gpt-4-vision-preview" - GPT4 = "gpt-4" - GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" - GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" - GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" - GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" - GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" - GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" + GPT432K0613 = "gpt-4-32k-0613" + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" + GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4oLatest = "chatgpt-4o-latest" + GPT4oAudioPreview = "gpt-4o-audio-preview" + GPT4oAudioPreview20241001 = "gpt-4o-audio-preview-2024-10-01" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" + GPT4VisionPreview = "gpt-4-vision-preview" + GPT4 = "gpt-4" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 6bd0964c6..4c8a65de2 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -1,7 +1,10 @@ package checks import ( + "bytes" + "encoding/json" "errors" + "reflect" "testing" ) @@ -53,3 +56,103 @@ func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) t.Fatalf(format, msg) } } + +type TestingT interface { + Fatalf(format string, args ...any) + Errorf(format string, args ...any) +} + +type tHelper interface { + Helper() +} + +// Equal asserts that two objects are equal. +// +// assert.Equal(t, 123, 123) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if err := validateEqualArgs(expected, actual); err != nil { + t.Fatalf("Invalid operation: %#v == %#v (%s)", expected, actual, err) + } + + if !ObjectsAreEqual(expected, actual) { + t.Fatalf("Not equal: \n"+ + "expected: %+v\n"+ + "actual : %+v", expected, actual) + } + + return true +} + +// JSONEq asserts that two JSON strings are equivalent. +// +// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + var expectedJSONAsInterface, actualJSONAsInterface interface{} + + if err := json.Unmarshal([]byte(expected), &expectedJSONAsInterface); err != nil { + t.Fatalf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()) + } + + if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil { + t.Fatalf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()) + } + + return Equal(t, expectedJSONAsInterface, actualJSONAsInterface, msgAndArgs...) +} + +// validateEqualArgs checks whether provided arguments can be safely used in the +// Equal/NotEqual functions. +func validateEqualArgs(expected, actual interface{}) error { + if expected == nil && actual == nil { + return nil + } + + if isFunction(expected) || isFunction(actual) { + return errors.New("cannot take func type as argument") + } + return nil +} + +func isFunction(arg interface{}) bool { + if arg == nil { + return false + } + return reflect.TypeOf(arg).Kind() == reflect.Func +} + +/* + Helper functions +*/ + +// ObjectsAreEqual determines if two objects are considered equal. +// +// This function does no assertion of any kind. +func ObjectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +}