diff --git a/chat_test.go b/chat_test.go index 52cd0bdef..011c533f6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -17,10 +17,20 @@ import ( ) const ( - xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeader = "" xCustomHeaderValue = "test" ) +var testHttpHeaders = map[string]string{ + "X-CUSTOM-HEADER": "test", + "x-ratelimit-limit-requests": "3000", + "x-ratelimit-limit-tokens": "250000", + "x-ratelimit-remaining-requests": "2999", + "x-ratelimit-remaining-tokens": "249999", + "x-ratelimit-reset-requests": "20ms", + "x-ratelimit-reset-tokens": "1ms", +} + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -90,9 +100,36 @@ func TestChatCompletionsWithHeaders(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion error") - a := resp.Header().Get(xCustomHeader) - _ = a - if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + if resp.RatelimitLimitRequests() != testHttpHeaders["x-ratelimit-limit-requests"] { + t.Errorf("expected header %s to be %s", "x-ratelimit-limit-requests", + testHttpHeaders["x-ratelimit-limit-requests"]) + } + + if resp.RatelimitLimitTokens() != testHttpHeaders["x-ratelimit-limit-tokens"] { + t.Errorf("expected header %s to be %s", "x-ratelimit-limit-tokens", + testHttpHeaders["x-ratelimit-limit-tokens"]) + } + + if resp.RatelimitRemainingRequests() != testHttpHeaders["x-ratelimit-remaining-requests"] { + t.Errorf("expected header %s to be %s", "x-ratelimit-remaining-requests", + testHttpHeaders["x-ratelimit-remaining-requests"]) + } + + if resp.RatelimitRemainingTokens() != testHttpHeaders["x-ratelimit-remaining-tokens"] { + t.Errorf("expected header %s to be %s", "x-ratelimit-remaining-tokens", + testHttpHeaders["x-ratelimit-remaining-tokens"]) + } + + if resp.RatelimitResetRequests() != testHttpHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected header %s to be %s", "x-ratelimit-reset-requests", + testHttpHeaders["x-ratelimit-reset-requests"]) + } + + if resp.RatelimitResetTokens() != testHttpHeaders["x-ratelimit-reset-tokens"] { + t.Errorf("expected header %s to be %s", "x-ratelimit-reset-tokens", testHttpHeaders["x-ratelimit-reset-tokens"]) + } + + if resp.Header().Get("x-custom-header") != xCustomHeaderValue { t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) } } @@ -310,7 +347,10 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) - w.Header().Set(xCustomHeader, xCustomHeaderValue) + // set test headers + for k, v := range testHttpHeaders { + w.Header().Set(k, v) + } fmt.Fprintln(w, string(resBytes)) }