GODRIVER-2914 x/mongo/driver: enable parallel zlib compression and improve zstd decompression #1320

Merged
merged 3 commits on Sep 5, 2023

Changes from all commits
114 changes: 70 additions & 44 deletions x/mongo/driver/compression.go
@@ -26,48 +26,72 @@ type CompressionOpts struct {
	UncompressedSize int32
}

-var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
+// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
+// destination writer. It panics on any errors and should only be used at
+// package initialization time.
+func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
+	enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
+	if err != nil {
+		panic(err)
+	}
+	return enc
+}
+
+var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
+	0:                           nil, // zstd.speedNotSet
+	zstd.SpeedFastest:           mustZstdNewWriter(zstd.SpeedFastest),
+	zstd.SpeedDefault:           mustZstdNewWriter(zstd.SpeedDefault),
+	zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
+	zstd.SpeedBestCompression:   mustZstdNewWriter(zstd.SpeedBestCompression),
+}

func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
-	if v, ok := zstdEncoders.Load(level); ok {
-		return v.(*zstd.Encoder), nil
-	}
-	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
-	if err != nil {
-		return nil, err
+	if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
+		return zstdEncoders[level], nil
	}
-	zstdEncoders.Store(level, encoder)
-	return encoder, nil
+	// The level is outside the expected range, return an error.
+	return nil, fmt.Errorf("invalid zstd compression level: %d", level)
}

-var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
+// zlibEncodersOffset is the offset into the zlibEncoders array for a given
+// compression level.
+const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2
+
+var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool

func getZlibEncoder(level int) (*zlibEncoder, error) {
-	if v, ok := zlibEncoders.Load(level); ok {
-		return v.(*zlibEncoder), nil
-	}
-	writer, err := zlib.NewWriterLevel(nil, level)
-	if err != nil {
-		return nil, err
+	if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
+		if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
+			return enc, nil
+		}
+		writer, err := zlib.NewWriterLevel(nil, level)
+		if err != nil {
+			return nil, err
+		}
+		enc := &zlibEncoder{writer: writer, level: level}
+		return enc, nil
	}
-	encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
-	zlibEncoders.Store(level, encoder)
+	// The level is outside the expected range, return an error.
+	return nil, fmt.Errorf("invalid zlib compression level: %d", level)
+}

-	return encoder, nil
+func putZlibEncoder(enc *zlibEncoder) {
+	if enc != nil {
+		zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
+	}
}

type zlibEncoder struct {
	mu     sync.Mutex
	writer *zlib.Writer
-	buf    *bytes.Buffer
+	buf    bytes.Buffer
+	level  int
}

func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
	e.mu.Lock()
	defer e.mu.Unlock()
+	defer putZlibEncoder(e)

	e.buf.Reset()
-	e.writer.Reset(e.buf)
+	e.writer.Reset(&e.buf)

	_, err := e.writer.Write(src)
	if err != nil {
@@ -105,8 +129,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
	}
}

+var zstdReaderPool = sync.Pool{
+	New: func() interface{} {
+		r, _ := zstd.NewReader(nil)
+		return r
+	},
+}
+
// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
-func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
+func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
	switch opts.Compressor {
	case wiremessage.CompressorNoOp:
		return in, nil
@@ -117,34 +148,29 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
		} else if int32(l) != opts.UncompressedSize {
			return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
		}
-		uncompressed = make([]byte, opts.UncompressedSize)
-		return snappy.Decode(uncompressed, in)
+		out := make([]byte, opts.UncompressedSize)
+		return snappy.Decode(out, in)
	case wiremessage.CompressorZLib:
		r, err := zlib.NewReader(bytes.NewReader(in))
		if err != nil {
			return nil, err
		}
-		defer func() {
-			err = r.Close()
-		}()
-		uncompressed = make([]byte, opts.UncompressedSize)
-		_, err = io.ReadFull(r, uncompressed)
-		if err != nil {
+		out := make([]byte, opts.UncompressedSize)
+		if _, err := io.ReadFull(r, out); err != nil {
			return nil, err
		}
-		return uncompressed, nil
-	case wiremessage.CompressorZstd:
-		r, err := zstd.NewReader(bytes.NewBuffer(in))
-		if err != nil {
-			return nil, err
-		}
-		defer r.Close()
-		uncompressed = make([]byte, opts.UncompressedSize)
-		_, err = io.ReadFull(r, uncompressed)
-		if err != nil {
+		if err := r.Close(); err != nil {
			return nil, err
		}
-		return uncompressed, nil
+		return out, nil
+	case wiremessage.CompressorZstd:
+		buf := make([]byte, 0, opts.UncompressedSize)
+		// Using a pool here is about ~20% faster
+		// than using a single global zstd.Reader
+		r := zstdReaderPool.Get().(*zstd.Decoder)
+		out, err := r.DecodeAll(in, buf)
+		zstdReaderPool.Put(r)
+		return out, err
	default:
		return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
	}
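For readers skimming the diff, the snippet below sketches how these exported helpers are typically exercised: compress a payload with the pooled zlib encoder, then hand the compressed bytes and the original length back to DecompressPayload. It is an illustrative example written for this review, not code from the PR; the standalone main package and the variable names are assumptions, while the functions, option fields, and constants come from the diff above and the wiremessage package.

// Illustrative sketch (not from this PR): round-tripping a payload through the
// driver's compression helpers using zlib.
package main

import (
	"bytes"
	"fmt"

	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

func main() {
	payload := []byte("example wire message body")

	opts := driver.CompressionOpts{
		Compressor: wiremessage.CompressorZLib,
		ZlibLevel:  wiremessage.DefaultZlibLevel,
	}

	// CompressPayload checks out a zlib encoder from a per-level pool, so
	// concurrent callers no longer serialize on one shared, mutex-guarded encoder.
	compressed, err := driver.CompressPayload(payload, opts)
	if err != nil {
		panic(err)
	}

	// Decompression needs the original size up front to allocate the output buffer.
	opts.UncompressedSize = int32(len(payload))
	roundTripped, err := driver.DecompressPayload(compressed, opts)
	if err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(payload, roundTripped)) // true
}

The benchmark changes in the test file below measure exactly this pair of calls across Snappy, zlib, and zstd.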
137 changes: 128 additions & 9 deletions x/mongo/driver/compression_test.go
@@ -7,9 +7,14 @@
package driver

import (
+	"bytes"
+	"compress/zlib"
	"os"
	"testing"

+	"github.com/golang/snappy"
+	"github.com/klauspost/compress/zstd"
+
	"go.mongodb.org/mongo-driver/internal/assert"
	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
@@ -41,6 +46,43 @@ func TestCompression(t *testing.T) {
	}
}

+func TestCompressionLevels(t *testing.T) {
+	in := []byte("abc")
+	wr := new(bytes.Buffer)
+
+	t.Run("ZLib", func(t *testing.T) {
+		opts := CompressionOpts{
+			Compressor: wiremessage.CompressorZLib,
+		}
+		for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ {
+			opts.ZlibLevel = lvl
+			_, err1 := CompressPayload(in, opts)
+			_, err2 := zlib.NewWriterLevel(wr, lvl)
+			if err2 != nil {
+				assert.Error(t, err1, "expected an error for ZLib level %d", lvl)
+			} else {
+				assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl)
+			}
+		}
+	})
+
+	t.Run("Zstd", func(t *testing.T) {
+		opts := CompressionOpts{
+			Compressor: wiremessage.CompressorZstd,
+		}
+		for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ {
+			opts.ZstdLevel = int(lvl)
+			_, err1 := CompressPayload(in, opts)
+			_, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
+			if err2 != nil {
+				assert.Error(t, err1, "expected an error for Zstd level %d", lvl)
+			} else {
+				assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl)
+			}
+		}
+	})
+}
+
func TestDecompressFailures(t *testing.T) {
	t.Parallel()

@@ -62,18 +104,57 @@ func TestDecompressFailures(t *testing.T) {
	})
}

-func BenchmarkCompressPayload(b *testing.B) {
-	payload := func() []byte {
-		buf, err := os.ReadFile("compression.go")
-		if err != nil {
-			b.Log(err)
-			b.FailNow()
-		}
-		for i := 1; i < 10; i++ {
-			buf = append(buf, buf...)
-		}
-		return buf
-	}()
+var (
+	compressionPayload      []byte
+	compressedSnappyPayload []byte
+	compressedZLibPayload   []byte
+	compressedZstdPayload   []byte
+)
+
+func initCompressionPayload(b *testing.B) {
+	if compressionPayload != nil {
+		return
+	}
+	data, err := os.ReadFile("testdata/compression.go")
+	if err != nil {
+		b.Fatal(err)
+	}
+	for i := 1; i < 10; i++ {
+		data = append(data, data...)
+	}
+	compressionPayload = data
+
+	compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data)
+
+	{
+		var buf bytes.Buffer
+		enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault))
+		if err != nil {
+			b.Fatal(err)
+		}
+		compressedZstdPayload = enc.EncodeAll(data, nil)
+	}
+
+	{
+		var buf bytes.Buffer
+		enc := zlib.NewWriter(&buf)
+		if _, err := enc.Write(data); err != nil {
+			b.Fatal(err)
+		}
+		if err := enc.Close(); err != nil {
+			b.Fatal(err)
+		}
+		if err := enc.Close(); err != nil {
+			b.Fatal(err)
+		}
+		compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...)
+	}
+
+	b.ResetTimer()
+}
+
+func BenchmarkCompressPayload(b *testing.B) {
+	initCompressionPayload(b)

	compressors := []wiremessage.CompressorID{
		wiremessage.CompressorSnappy,
@@ -88,6 +169,9 @@ func BenchmarkCompressPayload(b *testing.B) {
			ZlibLevel:  wiremessage.DefaultZlibLevel,
			ZstdLevel:  wiremessage.DefaultZstdLevel,
		}
+		payload := compressionPayload
+		b.SetBytes(int64(len(payload)))
+		b.ReportAllocs()
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				_, err := CompressPayload(payload, opts)
@@ -99,3 +183,38 @@
		})
	}
}
+
+func BenchmarkDecompressPayload(b *testing.B) {
+	initCompressionPayload(b)
+
+	benchmarks := []struct {
+		compressor wiremessage.CompressorID
+		payload    []byte
+	}{
+		{wiremessage.CompressorSnappy, compressedSnappyPayload},
+		{wiremessage.CompressorZLib, compressedZLibPayload},
+		{wiremessage.CompressorZstd, compressedZstdPayload},
+	}
+
+	for _, bench := range benchmarks {
+		b.Run(bench.compressor.String(), func(b *testing.B) {
+			opts := CompressionOpts{
+				Compressor:       bench.compressor,
+				ZlibLevel:        wiremessage.DefaultZlibLevel,
+				ZstdLevel:        wiremessage.DefaultZstdLevel,
+				UncompressedSize: int32(len(compressionPayload)),
+			}
+			payload := bench.payload
+			b.SetBytes(int64(len(compressionPayload)))
+			b.ReportAllocs()
+			b.RunParallel(func(pb *testing.PB) {
+				for pb.Next() {
+					_, err := DecompressPayload(payload, opts)
+					if err != nil {
+						b.Fatal(err)
+					}
+				}
+			})
+		})
+	}
+}
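As a companion to the zstd decompression change above, here is a stripped-down sketch of the pooled-decoder pattern the PR adopts (the in-code comment in compression.go cites roughly a 20% win over a single global reader). The decompressZstd helper, the pool variable name, and the main function are invented for illustration; only the klauspost/compress/zstd calls mirror the driver code.

// Illustrative sketch (not from this PR) of the pooled zstd decoder pattern.
package main

import (
	"fmt"
	"sync"

	"github.com/klauspost/compress/zstd"
)

// A pool of *zstd.Decoder values; each decoder is created with a nil source
// because it is only ever used through DecodeAll.
var decoderPool = sync.Pool{
	New: func() interface{} {
		d, _ := zstd.NewReader(nil)
		return d
	},
}

// decompressZstd mirrors the shape of the DecompressPayload zstd branch:
// borrow a decoder, decode into a pre-sized buffer, return the decoder.
func decompressZstd(in []byte, sizeHint int) ([]byte, error) {
	d := decoderPool.Get().(*zstd.Decoder)
	defer decoderPool.Put(d)
	return d.DecodeAll(in, make([]byte, 0, sizeHint))
}

func main() {
	enc, err := zstd.NewWriter(nil)
	if err != nil {
		panic(err)
	}
	msg := []byte("hello, zstd")
	compressed := enc.EncodeAll(msg, nil)

	out, err := decompressZstd(compressed, len(msg))
	fmt.Println(string(out), err) // hello, zstd <nil>
}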