From 873078278014e5c5be48a7b6fd7dc8cd36c63871 Mon Sep 17 00:00:00 2001 From: ayush Date: Sat, 23 Nov 2024 16:04:30 +0000 Subject: [PATCH 1/9] support for extra_body parameter for embeddings API --- client.go | 20 ++++++++++++++++++++ embeddings.go | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index ed8595e0..1bb79e5b 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,26 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Initialize args.body as a map[string]any if it's nil. + if args.body == nil { + args.body = make(map[string]any) + } + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if !ok { + // If it's not, initialize it as a map[string]any. + bodyMap = make(map[string]any) + args.body = bodyMap + } + // Add extraBody fields to args.body. + for key, value := range extraBody { + bodyMap[key] = value + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) diff --git a/embeddings.go b/embeddings.go index 4a0e682d..127e910b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -159,7 +159,10 @@ type EmbeddingRequest struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -186,7 +189,10 @@ type EmbeddingRequestStrings struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -196,6 +202,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -218,7 +225,10 @@ type EmbeddingRequestTokens struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -228,6 +238,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -241,11 +252,29 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() + + // Prepare the body with only the provided fields. + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + body := make(map[string]any) + body["input"] = baseReq.Input + body["model"] = baseReq.Model + if baseReq.User != "" { + body["user"] = baseReq.User + } + if baseReq.EncodingFormat != "" { + body["encoding_format"] = baseReq.EncodingFormat + } + if baseReq.Dimensions > 0 { // Assuming 0 means the field is not set + body["dimensions"] = baseReq.Dimensions + } req, err := c.newRequest( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(baseReq), + withBody(body), // Main request body. + withExtraBody(baseReq.ExtraBody), // Merge ExtraBody fields. ) if err != nil { return From 5b15527672e324d5c4e55ff4f3d891fb4774a89b Mon Sep 17 00:00:00 2001 From: ayush Date: Sat, 23 Nov 2024 16:25:29 +0000 Subject: [PATCH 2/9] done linting --- embeddings.go | 18 +++++++++--------- test.mp3 | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 test.mp3 diff --git a/embeddings.go b/embeddings.go index 127e910b..f04f1a28 100644 --- a/embeddings.go +++ b/embeddings.go @@ -159,10 +159,10 @@ type EmbeddingRequest struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` // The ExtraBody field allows for the inclusion of arbitrary key-value pairs // in the request body that may not be explicitly defined in this struct. - ExtraBody map[string]any `json:"extra_body,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -189,10 +189,10 @@ type EmbeddingRequestStrings struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` // The ExtraBody field allows for the inclusion of arbitrary key-value pairs // in the request body that may not be explicitly defined in this struct. - ExtraBody map[string]any `json:"extra_body,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -225,10 +225,10 @@ type EmbeddingRequestTokens struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` // The ExtraBody field allows for the inclusion of arbitrary key-value pairs // in the request body that may not be explicitly defined in this struct. - ExtraBody map[string]any `json:"extra_body,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -254,8 +254,8 @@ func (c *Client) CreateEmbeddings( baseReq := conv.Convert() // Prepare the body with only the provided fields. - // The body map is used to dynamically construct the request payload for the embedding API. - // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields // based on their presence, avoiding unnecessary or empty fields in the request. body := make(map[string]any) body["input"] = baseReq.Input @@ -273,7 +273,7 @@ func (c *Client) CreateEmbeddings( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(body), // Main request body. + withBody(body), // Main request body. withExtraBody(baseReq.ExtraBody), // Merge ExtraBody fields. ) if err != nil { diff --git a/test.mp3 b/test.mp3 new file mode 100644 index 00000000..b6fc4c62 --- /dev/null +++ b/test.mp3 @@ -0,0 +1 @@ +hello \ No newline at end of file From 88775cf9bff49323ef76ac7673dcdf7e0c393302 Mon Sep 17 00:00:00 2001 From: ayush Date: Sat, 23 Nov 2024 16:43:22 +0000 Subject: [PATCH 3/9] added unit tests --- embeddings_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/embeddings_test.go b/embeddings_test.go index 43897816..3eca4c31 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) { t.Fatalf("Expected embedding request to contain model field") } + // test embedding request with strings and extra_body param + embeddingReqWithExtraBody := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + } + marshaled, err = json.Marshal(embeddingReqWithExtraBody) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + // test embedding request with strings embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ @@ -124,6 +142,21 @@ func TestEmbeddingEndpoint(t *testing.T) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } + // test create embeddings with strings (ExtraBody in request) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + // test create embeddings with strings (simple embedding request) res, err = client.CreateEmbeddings( context.Background(), From b69335b889ff912b32230e72a56bb2b6ebc8c7ba Mon Sep 17 00:00:00 2001 From: ayush Date: Sat, 23 Nov 2024 18:12:19 +0000 Subject: [PATCH 4/9] improved code coverage and removed unnecessary checks --- client.go | 18 ++++++------------ embeddings_test.go | 1 + 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 1bb79e5b..4782726c 100644 --- a/client.go +++ b/client.go @@ -86,20 +86,14 @@ func withBody(body any) requestOption { func withExtraBody(extraBody map[string]any) requestOption { return func(args *requestOptions) { - // Initialize args.body as a map[string]any if it's nil. - if args.body == nil { - args.body = make(map[string]any) - } // Assert that args.body is a map[string]any. bodyMap, ok := args.body.(map[string]any) - if !ok { - // If it's not, initialize it as a map[string]any. - bodyMap = make(map[string]any) - args.body = bodyMap - } - // Add extraBody fields to args.body. - for key, value := range extraBody { - bodyMap[key] = value + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. + for key, value := range extraBody { + bodyMap[key] = value + } } } } diff --git a/embeddings_test.go b/embeddings_test.go index 3eca4c31..095a8f47 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -150,6 +150,7 @@ func TestEmbeddingEndpoint(t *testing.T) { "input_type": "query", "truncate": "NONE", }, + Dimensions: 1, }, ) checks.NoError(t, err, "CreateEmbeddings error") From 2e5a45e8b0827f305be554ad99ef6b7d1126615a Mon Sep 17 00:00:00 2001 From: ayush Date: Sun, 24 Nov 2024 08:06:29 +0000 Subject: [PATCH 5/9] test cleanup --- test.mp3 | 1 - 1 file changed, 1 deletion(-) delete mode 100644 test.mp3 diff --git a/test.mp3 b/test.mp3 deleted file mode 100644 index b6fc4c62..00000000 --- a/test.mp3 +++ /dev/null @@ -1 +0,0 @@ -hello \ No newline at end of file From fc529512613bc897b2b16de9ca55a127824ab052 Mon Sep 17 00:00:00 2001 From: ayush Date: Sun, 1 Dec 2024 05:32:13 +0000 Subject: [PATCH 6/9] updated body map creation code --- embeddings.go | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/embeddings.go b/embeddings.go index f04f1a28..12a9d3ec 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "math" "net/http" @@ -253,28 +254,30 @@ func (c *Client) CreateEmbeddings( ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - // Prepare the body with only the provided fields. // The body map is used to dynamically construct the request payload for the embedding API. // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields // based on their presence, avoiding unnecessary or empty fields in the request. - body := make(map[string]any) - body["input"] = baseReq.Input - body["model"] = baseReq.Model - if baseReq.User != "" { - body["user"] = baseReq.User - } - if baseReq.EncodingFormat != "" { - body["encoding_format"] = baseReq.EncodingFormat + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return } - if baseReq.Dimensions > 0 { // Assuming 0 means the field is not set - body["dimensions"] = baseReq.Dimensions + + // Deserialize JSON to map[string]any + var body map[string]any + if err = json.Unmarshal(jsonData, &body); err != nil { + return } + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(body), // Main request body. - withExtraBody(baseReq.ExtraBody), // Merge ExtraBody fields. + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. ) if err != nil { return From a30d2f88879ddcb4f1d45593c79a413e393e406f Mon Sep 17 00:00:00 2001 From: ayush Date: Sun, 1 Dec 2024 06:26:27 +0000 Subject: [PATCH 7/9] code coverage --- embeddings.go | 4 +--- embeddings_test.go | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/embeddings.go b/embeddings.go index 12a9d3ec..8593f8b5 100644 --- a/embeddings.go +++ b/embeddings.go @@ -268,9 +268,7 @@ func (c *Client) CreateEmbeddings( // Deserialize JSON to map[string]any var body map[string]any - if err = json.Unmarshal(jsonData, &body); err != nil { - return - } + _ = json.Unmarshal(jsonData, &body) req, err := c.newRequest( ctx, diff --git a/embeddings_test.go b/embeddings_test.go index 095a8f47..fb2484ee 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -158,7 +158,20 @@ func TestEmbeddingEndpoint(t *testing.T) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } - // test create embeddings with strings (simple embedding request) + // test create embeddings with strings (ExtraBody in request and ) + _, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": make(chan int), // Channels cannot be serialized into JSON + }, + Dimensions: 1, + }, + ) + checks.HasError(t, err, "CreateEmbeddings error") + + // test failed (Serialize JSON error) res, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ From 745703a81308efe5732564f4388fd7b7042f9766 Mon Sep 17 00:00:00 2001 From: ayush Date: Sun, 1 Dec 2024 06:37:26 +0000 Subject: [PATCH 8/9] minor change --- embeddings_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/embeddings_test.go b/embeddings_test.go index fb2484ee..62c6f16d 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -162,11 +162,8 @@ func TestEmbeddingEndpoint(t *testing.T) { _, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ - ExtraBody: map[string]any{ - "input_type": "query", - "truncate": make(chan int), // Channels cannot be serialized into JSON - }, - Dimensions: 1, + Input: make(chan int), // Invalid UTF-8 string + Model: "example_model", }, ) checks.HasError(t, err, "CreateEmbeddings error") From e0ba3d6e02c0513c883c83c24b1b6b26626743aa Mon Sep 17 00:00:00 2001 From: ayush Date: Sun, 1 Dec 2024 06:38:40 +0000 Subject: [PATCH 9/9] updated testcase comment --- embeddings_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embeddings_test.go b/embeddings_test.go index 62c6f16d..07f1262c 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -162,7 +162,7 @@ func TestEmbeddingEndpoint(t *testing.T) { _, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ - Input: make(chan int), // Invalid UTF-8 string + Input: make(chan int), // Channels are not serializable Model: "example_model", }, )