diff --git a/adapter_test.go b/adapter_test.go index 7141eae..3fb5bb8 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -906,6 +906,111 @@ type noopHandler struct{} func (noopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {} +func TestWriteStringNoCompressionStatic(t *testing.T) { + t.Parallel() + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if w, ok := w.(interface{ WriteString(string) (int, error) }); ok { + w.WriteString("hello string world!") + return + } + w.Write([]byte("hello bytes world!")) + }) + a, _ := DefaultAdapter() + h = a(h) + // Do not send accept-encoding to disable compression + r, _ := http.NewRequest("GET", "/", nil) + t.Run("WriteString", func(t *testing.T) { + w := &discardResponseWriterWithWriteString{} + h.ServeHTTP(w, r) + if w.s != 19 { + t.Fatalf("WriteString not called: %+v", w) + } + }) + t.Run("Write", func(t *testing.T) { + w := &discardResponseWriter{} + h.ServeHTTP(w, r) + if w.b != 18 { + t.Fatalf("Write not called: %+v", w) + } + }) +} + +func TestWriteStringNoCompressionDynamic(t *testing.T) { + t.Parallel() + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/uncompressible") + if w, ok := w.(interface{ WriteString(string) (int, error) }); ok { + w.WriteString(testBody) // first WriteString will fallback to Write + w.WriteString(testBody) + return + } + w.Write([]byte(testBody)) + w.Write([]byte(testBody)) + }) + a, _ := DefaultAdapter(ContentTypes([]string{"text/uncompressible"}, true)) + h = a(h) + r, _ := http.NewRequest("GET", "/", nil) + r.Header.Set("Accept-Encoding", "gzip") + t.Run("WriteString", func(t *testing.T) { + w := &discardResponseWriterWithWriteString{} + h.ServeHTTP(w, r) + if w.s != len(testBody) || w.b != len(testBody) { // first WriteString falls back to Write + t.Fatalf("WriteString not called: %+v", w) + } + }) + t.Run("Write", func(t *testing.T) { + w := &discardResponseWriter{} + h.ServeHTTP(w, r) + if w.b != len(testBody)*2 { + t.Fatalf("Write not called: %+v", w) + } + }) +} + +type discardResponseWriterWithWriteString struct { + discardResponseWriter + s int +} + +func (w *discardResponseWriterWithWriteString) WriteString(s string) (n int, err error) { + w.s += len(s) + return len(s), nil +} + +func TestWriteStringEquivalence(t *testing.T) { + t.Parallel() + + for _, ae := range []string{"gzip", "uncompressed"} { + for _, ct := range []string{"text", "uncompressible"} { + t.Run(fmt.Sprintf("%s/%s", ae, ct), func(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) + r.Header.Set("Accept-Encoding", ae) + a, _ := DefaultAdapter(ContentTypes([]string{"uncompressible"}, true)) + + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", ct) + w.(interface{ WriteString(string) (int, error) }).WriteString(testBody) + w.(interface{ WriteString(string) (int, error) }).WriteString(testBody) + }) + h = a(h) + ws := httptest.NewRecorder() + h.ServeHTTP(ws, r) + + h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", ct) + w.Write([]byte(testBody)) + w.Write([]byte(testBody)) + }) + h = a(h) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + assert.Equal(t, ws.Body.Bytes(), w.Body.Bytes(), "response body mismatch") + }) + } + } +} + // -------------------------------------------------------------------- func BenchmarkAdapter(b *testing.B) { diff --git a/response_writer.go b/response_writer.go index ec15a37..9cd97de 100644 --- a/response_writer.go +++ b/response_writer.go @@ -32,6 +32,7 @@ var ( _ io.WriteCloser = &compressWriter{} _ http.Flusher = &compressWriter{} _ http.Hijacker = &compressWriter{} + _ writeStringer = &compressWriter{} ) type compressWriterWithCloseNotify struct { @@ -46,11 +47,12 @@ var ( _ io.WriteCloser = compressWriterWithCloseNotify{} _ http.Flusher = compressWriterWithCloseNotify{} _ http.Hijacker = compressWriterWithCloseNotify{} + _ writeStringer = compressWriterWithCloseNotify{} ) const maxBuf = 1 << 16 // maximum size of recycled buffer -// Write appends data to the gzip writer. +// WriteString compresses and appends the given byte slice to the underlying ResponseWriter. func (w *compressWriter) Write(b []byte) (int, error) { if w.w != nil { // The responseWriter is already initialized: use it. @@ -108,6 +110,28 @@ func (w *compressWriter) Write(b []byte) (int, error) { return len(b), nil } +// WriteString compresses and appends the given string to the underlying ResponseWriter. +// +// This makes use of an optional method (WriteString) exposed by the compressors, or by +// the underlying ResponseWriter. +func (w *compressWriter) WriteString(s string) (int, error) { + // Since WriteString is an optional interface of the compressor, and the actual compressor + // is chosen only after the first call to Write, we can't statically know whether the interface + // is supported. We therefore have to check dynamically. + if ws, _ := w.w.(writeStringer); ws != nil { + // The responseWriter is already initialized and it implements WriteString. + return ws.WriteString(s) + } + // Fallback: the writer has not been initialized yet, or it has been initialized + // and it does not implement WriteString. We could in theory do something unsafe + // here but for now let's keep it simple and fallback to Write. + return w.Write([]byte(s)) +} + +type writeStringer interface { + WriteString(string) (int, error) +} + // startCompress initializes a compressing writer and writes the buffer. func (w *compressWriter) startCompress(enc string) error { comp, ok := w.config.compressor[enc]