Commit

GODRIVER-2914 x/mongo/driver: enable parallel zlib compression and improve zstd decompression (#1320)

Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com>
charlievieth and matthewdale authored Sep 5, 2023
1 parent a3fcd76 commit 84a4385
Showing 3 changed files with 349 additions and 53 deletions.
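In short, the change replaces the driver's lazily populated sync.Map caches with fixed, per-level pools: zstd messages are compressed with encoders built once at package init, zlib messages are compressed with a per-level sync.Pool of writers so concurrent goroutines no longer serialize on one shared *zlib.Writer, and zstd messages are decompressed with pooled zstd.Decoder.DecodeAll calls instead of a streaming reader per message. The sketch below illustrates the same pooling pattern in isolation; it is a simplified standalone example (positive zlib levels only, and the names encodeZlib and decodeZstd are illustrative), not the driver's API.

package main

import (
    "bytes"
    "compress/zlib"
    "fmt"
    "sync"

    "github.com/klauspost/compress/zstd"
)

// One pool of reusable *zlib.Writer per compression level (0..9), so
// concurrent callers don't contend on a single shared writer.
var zlibPools [zlib.BestCompression + 1]sync.Pool

// encodeZlib compresses src at the given level using a pooled writer.
func encodeZlib(src []byte, level int) ([]byte, error) {
    if level < zlib.NoCompression || level > zlib.BestCompression {
        return nil, fmt.Errorf("invalid zlib compression level: %d", level)
    }
    w, _ := zlibPools[level].Get().(*zlib.Writer)
    if w == nil {
        var err error
        if w, err = zlib.NewWriterLevel(nil, level); err != nil {
            return nil, err
        }
    }
    defer zlibPools[level].Put(w)

    var buf bytes.Buffer
    w.Reset(&buf) // point the reused writer at a fresh output buffer
    if _, err := w.Write(src); err != nil {
        return nil, err
    }
    if err := w.Close(); err != nil {
        return nil, err
    }
    return buf.Bytes(), nil
}

// A pool of zstd decoders; DecodeAll decodes a whole message in one call,
// avoiding a new streaming reader per message.
var zstdDecoders = sync.Pool{
    New: func() interface{} {
        d, _ := zstd.NewReader(nil)
        return d
    },
}

// decodeZstd decompresses in, using sizeHint to pre-size the output buffer.
func decodeZstd(in []byte, sizeHint int) ([]byte, error) {
    d := zstdDecoders.Get().(*zstd.Decoder)
    defer zstdDecoders.Put(d)
    return d.DecodeAll(in, make([]byte, 0, sizeHint))
}

func main() {
    payload := bytes.Repeat([]byte("mongodb wire message "), 64)

    zlibOut, err := encodeZlib(payload, zlib.BestCompression)
    if err != nil {
        panic(err)
    }

    enc, _ := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault))
    zstdOut := enc.EncodeAll(payload, nil)

    roundTrip, err := decodeZstd(zstdOut, len(payload))
    if err != nil {
        panic(err)
    }
    fmt.Printf("zlib: %d bytes, zstd: %d bytes, zstd round-trip ok: %v\n",
        len(zlibOut), len(zstdOut), bytes.Equal(roundTrip, payload))
}

Because each goroutine takes a writer or decoder from a pool and returns it when done, compression of independent wire messages can proceed in parallel without a global lock, which is the behavior the commit title refers to.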
114 changes: 70 additions & 44 deletions x/mongo/driver/compression.go
@@ -26,48 +26,72 @@ type CompressionOpts struct {
UncompressedSize int32
}

-var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
+// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
+// destination writer. It panics on any errors and should only be used at
+// package initialization time.
+func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
+    enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
+    if err != nil {
+        panic(err)
+    }
+    return enc
+}
+
+var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
+    0:                           nil, // zstd.speedNotSet
+    zstd.SpeedFastest:           mustZstdNewWriter(zstd.SpeedFastest),
+    zstd.SpeedDefault:           mustZstdNewWriter(zstd.SpeedDefault),
+    zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
+    zstd.SpeedBestCompression:   mustZstdNewWriter(zstd.SpeedBestCompression),
+}

func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
-    if v, ok := zstdEncoders.Load(level); ok {
-        return v.(*zstd.Encoder), nil
-    }
-    encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
-    if err != nil {
-        return nil, err
-    }
-    zstdEncoders.Store(level, encoder)
-    return encoder, nil
+    if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
+        return zstdEncoders[level], nil
+    }
+    // The level is outside the expected range, return an error.
+    return nil, fmt.Errorf("invalid zstd compression level: %d", level)
}

-var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
+// zlibEncodersOffset is the offset into the zlibEncoders array for a given
+// compression level.
+const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2
+
+var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool

func getZlibEncoder(level int) (*zlibEncoder, error) {
-    if v, ok := zlibEncoders.Load(level); ok {
-        return v.(*zlibEncoder), nil
-    }
-    writer, err := zlib.NewWriterLevel(nil, level)
-    if err != nil {
-        return nil, err
-    }
-    encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
-    zlibEncoders.Store(level, encoder)
-
-    return encoder, nil
+    if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
+        if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
+            return enc, nil
+        }
+        writer, err := zlib.NewWriterLevel(nil, level)
+        if err != nil {
+            return nil, err
+        }
+        enc := &zlibEncoder{writer: writer, level: level}
+        return enc, nil
+    }
+    // The level is outside the expected range, return an error.
+    return nil, fmt.Errorf("invalid zlib compression level: %d", level)
+}
+
+func putZlibEncoder(enc *zlibEncoder) {
+    if enc != nil {
+        zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
+    }
}

type zlibEncoder struct {
mu sync.Mutex
writer *zlib.Writer
-    buf    *bytes.Buffer
+    buf    bytes.Buffer
+    level  int
}

func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
e.mu.Lock()
defer e.mu.Unlock()
+    defer putZlibEncoder(e)

    e.buf.Reset()
-    e.writer.Reset(e.buf)
+    e.writer.Reset(&e.buf)

_, err := e.writer.Write(src)
if err != nil {
@@ -105,8 +129,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
}
}

var zstdReaderPool = sync.Pool{
New: func() interface{} {
r, _ := zstd.NewReader(nil)
return r
},
}

// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
-func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
+func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
@@ -117,34 +148,29 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
} else if int32(l) != opts.UncompressedSize {
return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
}
-        uncompressed = make([]byte, opts.UncompressedSize)
-        return snappy.Decode(uncompressed, in)
+        out := make([]byte, opts.UncompressedSize)
+        return snappy.Decode(out, in)
case wiremessage.CompressorZLib:
r, err := zlib.NewReader(bytes.NewReader(in))
if err != nil {
return nil, err
}
-        defer func() {
-            err = r.Close()
-        }()
-        uncompressed = make([]byte, opts.UncompressedSize)
-        _, err = io.ReadFull(r, uncompressed)
-        if err != nil {
-            return nil, err
-        }
-        return uncompressed, nil
-    case wiremessage.CompressorZstd:
-        r, err := zstd.NewReader(bytes.NewBuffer(in))
-        if err != nil {
-            return nil, err
-        }
-        defer r.Close()
-        uncompressed = make([]byte, opts.UncompressedSize)
-        _, err = io.ReadFull(r, uncompressed)
-        if err != nil {
-            return nil, err
-        }
-        return uncompressed, nil
+        out := make([]byte, opts.UncompressedSize)
+        if _, err := io.ReadFull(r, out); err != nil {
+            return nil, err
+        }
+        if err := r.Close(); err != nil {
+            return nil, err
+        }
+        return out, nil
+    case wiremessage.CompressorZstd:
+        buf := make([]byte, 0, opts.UncompressedSize)
+        // Using a pool here is about ~20% faster
+        // than using a single global zstd.Reader
+        r := zstdReaderPool.Get().(*zstd.Decoder)
+        out, err := r.DecodeAll(in, buf)
+        zstdReaderPool.Put(r)
+        return out, err
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
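Before this change, the zstd branch of DecompressPayload built a streaming reader for every message and filled a pre-sized buffer with io.ReadFull; the new branch reuses pooled decoders and decodes the whole message with a single DecodeAll call (the in-code comment estimates roughly a 20% win over a single shared reader). A standalone micro-benchmark along the following lines can contrast the per-message reader with the pooled DecodeAll path; the package name, pool variable, and payload are illustrative and not part of the driver.

package zstdpool

import (
    "bytes"
    "io"
    "sync"
    "testing"

    "github.com/klauspost/compress/zstd"
)

var (
    compressed      []byte
    uncompressedLen int

    decoderPool = sync.Pool{
        New: func() interface{} {
            d, _ := zstd.NewReader(nil)
            return d
        },
    }
)

func init() {
    // Build a synthetic zstd payload once so both benchmarks decode the same input.
    plain := bytes.Repeat([]byte("mongodb wire message "), 1<<10)
    uncompressedLen = len(plain)
    enc, _ := zstd.NewWriter(nil)
    compressed = enc.EncodeAll(plain, nil)
}

// Old-style path: construct a streaming reader for every message.
func BenchmarkZstdStreamingReader(b *testing.B) {
    b.ReportAllocs()
    for i := 0; i < b.N; i++ {
        r, err := zstd.NewReader(bytes.NewReader(compressed))
        if err != nil {
            b.Fatal(err)
        }
        out := make([]byte, uncompressedLen)
        if _, err := io.ReadFull(r, out); err != nil {
            b.Fatal(err)
        }
        r.Close()
    }
}

// New-style path: reuse pooled decoders and decode in one call.
func BenchmarkZstdPooledDecodeAll(b *testing.B) {
    b.ReportAllocs()
    for i := 0; i < b.N; i++ {
        d := decoderPool.Get().(*zstd.Decoder)
        out, err := d.DecodeAll(compressed, make([]byte, 0, uncompressedLen))
        decoderPool.Put(d)
        if err != nil || len(out) != uncompressedLen {
            b.Fatalf("unexpected decode result: err=%v len=%d", err, len(out))
        }
    }
}

Running both with go test -bench=Zstd on a representative payload is what would substantiate the ratio quoted in the comment; the exact number depends on payload size and CPU count.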
137 changes: 128 additions & 9 deletions x/mongo/driver/compression_test.go
@@ -7,9 +7,14 @@
package driver

import (
"bytes"
"compress/zlib"
"os"
"testing"

"github.com/golang/snappy"
"github.com/klauspost/compress/zstd"

"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
@@ -41,6 +46,43 @@ func TestCompression(t *testing.T) {
}
}

func TestCompressionLevels(t *testing.T) {
in := []byte("abc")
wr := new(bytes.Buffer)

t.Run("ZLib", func(t *testing.T) {
opts := CompressionOpts{
Compressor: wiremessage.CompressorZLib,
}
for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ {
opts.ZlibLevel = lvl
_, err1 := CompressPayload(in, opts)
_, err2 := zlib.NewWriterLevel(wr, lvl)
if err2 != nil {
assert.Error(t, err1, "expected an error for ZLib level %d", lvl)
} else {
assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl)
}
}
})

t.Run("Zstd", func(t *testing.T) {
opts := CompressionOpts{
Compressor: wiremessage.CompressorZstd,
}
for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ {
opts.ZstdLevel = int(lvl)
_, err1 := CompressPayload(in, opts)
_, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
if err2 != nil {
assert.Error(t, err1, "expected an error for Zstd level %d", lvl)
} else {
assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl)
}
}
})
}

func TestDecompressFailures(t *testing.T) {
t.Parallel()

@@ -62,18 +104,57 @@ func TestDecompressFailures(t *testing.T) {
})
}

-func BenchmarkCompressPayload(b *testing.B) {
-    payload := func() []byte {
-        buf, err := os.ReadFile("compression.go")
-        if err != nil {
-            b.Log(err)
-            b.FailNow()
-        }
-        for i := 1; i < 10; i++ {
-            buf = append(buf, buf...)
-        }
-        return buf
-    }()
+var (
+    compressionPayload      []byte
+    compressedSnappyPayload []byte
+    compressedZLibPayload   []byte
+    compressedZstdPayload   []byte
+)
+
+func initCompressionPayload(b *testing.B) {
+    if compressionPayload != nil {
+        return
+    }
+    data, err := os.ReadFile("testdata/compression.go")
+    if err != nil {
+        b.Fatal(err)
+    }
+    for i := 1; i < 10; i++ {
+        data = append(data, data...)
+    }
+    compressionPayload = data
+
+    compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data)
+
+    {
+        var buf bytes.Buffer
+        enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault))
+        if err != nil {
+            b.Fatal(err)
+        }
+        compressedZstdPayload = enc.EncodeAll(data, nil)
+    }
+
+    {
+        var buf bytes.Buffer
+        enc := zlib.NewWriter(&buf)
+        if _, err := enc.Write(data); err != nil {
+            b.Fatal(err)
+        }
+        if err := enc.Close(); err != nil {
+            b.Fatal(err)
+        }
+        if err := enc.Close(); err != nil {
+            b.Fatal(err)
+        }
+        compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...)
+    }
+
+    b.ResetTimer()
+}
+
+func BenchmarkCompressPayload(b *testing.B) {
+    initCompressionPayload(b)

compressors := []wiremessage.CompressorID{
wiremessage.CompressorSnappy,
@@ -88,6 +169,9 @@ func BenchmarkCompressPayload(b *testing.B) {
ZlibLevel: wiremessage.DefaultZlibLevel,
ZstdLevel: wiremessage.DefaultZstdLevel,
}
payload := compressionPayload
b.SetBytes(int64(len(payload)))
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := CompressPayload(payload, opts)
@@ -99,3 +183,38 @@ func BenchmarkCompressPayload(b *testing.B) {
})
}
}

func BenchmarkDecompressPayload(b *testing.B) {
initCompressionPayload(b)

benchmarks := []struct {
compressor wiremessage.CompressorID
payload []byte
}{
{wiremessage.CompressorSnappy, compressedSnappyPayload},
{wiremessage.CompressorZLib, compressedZLibPayload},
{wiremessage.CompressorZstd, compressedZstdPayload},
}

for _, bench := range benchmarks {
b.Run(bench.compressor.String(), func(b *testing.B) {
opts := CompressionOpts{
Compressor: bench.compressor,
ZlibLevel: wiremessage.DefaultZlibLevel,
ZstdLevel: wiremessage.DefaultZstdLevel,
UncompressedSize: int32(len(compressionPayload)),
}
payload := bench.payload
b.SetBytes(int64(len(compressionPayload)))
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := DecompressPayload(payload, opts)
if err != nil {
b.Fatal(err)
}
}
})
})
}
}
