Skip to content

Commit

Permalink
add WriteString
Browse files Browse the repository at this point in the history
  • Loading branch information
CAFxX committed Jan 31, 2021
1 parent b6a62f1 commit 5686044
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 1 deletion.
105 changes: 105 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,111 @@ type noopHandler struct{}

func (noopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {}

func TestWriteStringNoCompressionStatic(t *testing.T) {
t.Parallel()
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if w, ok := w.(interface{ WriteString(string) (int, error) }); ok {
w.WriteString("hello string world!")
return
}
w.Write([]byte("hello bytes world!"))
})
a, _ := DefaultAdapter()
h = a(h)
// Do not send accept-encoding to disable compression
r, _ := http.NewRequest("GET", "/", nil)
t.Run("WriteString", func(t *testing.T) {
w := &discardResponseWriterWithWriteString{}
h.ServeHTTP(w, r)
if w.s != 19 {
t.Fatalf("WriteString not called: %+v", w)
}
})
t.Run("Write", func(t *testing.T) {
w := &discardResponseWriter{}
h.ServeHTTP(w, r)
if w.b != 18 {
t.Fatalf("Write not called: %+v", w)
}
})
}

func TestWriteStringNoCompressionDynamic(t *testing.T) {
t.Parallel()
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/uncompressible")
if w, ok := w.(interface{ WriteString(string) (int, error) }); ok {
w.WriteString(testBody) // first WriteString will fallback to Write
w.WriteString(testBody)
return
}
w.Write([]byte(testBody))
w.Write([]byte(testBody))
})
a, _ := DefaultAdapter(ContentTypes([]string{"text/uncompressible"}, true))
h = a(h)
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Accept-Encoding", "gzip")
t.Run("WriteString", func(t *testing.T) {
w := &discardResponseWriterWithWriteString{}
h.ServeHTTP(w, r)
if w.s != len(testBody) || w.b != len(testBody) { // first WriteString falls back to Write
t.Fatalf("WriteString not called: %+v", w)
}
})
t.Run("Write", func(t *testing.T) {
w := &discardResponseWriter{}
h.ServeHTTP(w, r)
if w.b != len(testBody)*2 {
t.Fatalf("Write not called: %+v", w)
}
})
}

type discardResponseWriterWithWriteString struct {
discardResponseWriter
s int
}

func (w *discardResponseWriterWithWriteString) WriteString(s string) (n int, err error) {
w.s += len(s)
return len(s), nil
}

func TestWriteStringEquivalence(t *testing.T) {
t.Parallel()

for _, ae := range []string{"gzip", "uncompressed"} {
for _, ct := range []string{"text", "uncompressible"} {
t.Run(fmt.Sprintf("%s/%s", ae, ct), func(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Accept-Encoding", ae)
a, _ := DefaultAdapter(ContentTypes([]string{"uncompressible"}, true))

var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", ct)
w.(interface{ WriteString(string) (int, error) }).WriteString(testBody)
w.(interface{ WriteString(string) (int, error) }).WriteString(testBody)
})
h = a(h)
ws := httptest.NewRecorder()
h.ServeHTTP(ws, r)

h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", ct)
w.Write([]byte(testBody))
w.Write([]byte(testBody))
})
h = a(h)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)

assert.Equal(t, ws.Body.Bytes(), w.Body.Bytes(), "response body mismatch")
})
}
}
}

// --------------------------------------------------------------------

func BenchmarkAdapter(b *testing.B) {
Expand Down
26 changes: 25 additions & 1 deletion response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
_ io.WriteCloser = &compressWriter{}
_ http.Flusher = &compressWriter{}
_ http.Hijacker = &compressWriter{}
_ writeStringer = &compressWriter{}
)

type compressWriterWithCloseNotify struct {
Expand All @@ -46,11 +47,12 @@ var (
_ io.WriteCloser = compressWriterWithCloseNotify{}
_ http.Flusher = compressWriterWithCloseNotify{}
_ http.Hijacker = compressWriterWithCloseNotify{}
_ writeStringer = compressWriterWithCloseNotify{}
)

const maxBuf = 1 << 16 // maximum size of recycled buffer

// Write appends data to the gzip writer.
// WriteString compresses and appends the given byte slice to the underlying ResponseWriter.
func (w *compressWriter) Write(b []byte) (int, error) {
if w.w != nil {
// The responseWriter is already initialized: use it.
Expand Down Expand Up @@ -108,6 +110,28 @@ func (w *compressWriter) Write(b []byte) (int, error) {
return len(b), nil
}

// WriteString compresses and appends the given string to the underlying ResponseWriter.
//
// This makes use of an optional method (WriteString) exposed by the compressors, or by
// the underlying ResponseWriter.
func (w *compressWriter) WriteString(s string) (int, error) {
// Since WriteString is an optional interface of the compressor, and the actual compressor
// is chosen only after the first call to Write, we can't statically know whether the interface
// is supported. We therefore have to check dynamically.
if ws, _ := w.w.(writeStringer); ws != nil {
// The responseWriter is already initialized and it implements WriteString.
return ws.WriteString(s)
}
// Fallback: the writer has not been initialized yet, or it has been initialized
// and it does not implement WriteString. We could in theory do something unsafe
// here but for now let's keep it simple and fallback to Write.
return w.Write([]byte(s))
}

type writeStringer interface {
WriteString(string) (int, error)
}

// startCompress initializes a compressing writer and writes the buffer.
func (w *compressWriter) startCompress(enc string) error {
comp, ok := w.config.compressor[enc]
Expand Down

0 comments on commit 5686044

Please sign in to comment.