Skip to content

Commit

Permalink
x/mongo/driver: enable parallel zlib compression and improve zstd dec…
Browse files Browse the repository at this point in the history
…ompression

This commit fixes a bug where zlib compression was serialized across all
goroutines. This occurred because only one shared zlib decompresser was
instantiated for each compression level which had to be locked when used
and thus preventing concurrent compression.

This commit also slightly improves zstd decompression performance by
using a pool of zstd decoders (instantiating a zstd encoded or decoder
is fairly expensive). It also slightly cleans up the logic used to store
and acquire zstd encoders.

```
goos: darwin
goarch: arm64
pkg: go.mongodb.org/mongo-driver/x/mongo/driver
                                    │ base.20.txt  │             new.20.txt              │
                                    │    sec/op    │   sec/op     vs base                │
CompressPayload/CompressorZLib-10     5387.4µ ± 0%   651.1µ ± 1%  -87.91% (p=0.000 n=20)
CompressPayload/CompressorZstd-10      64.56µ ± 1%   64.10µ ± 0%   -0.72% (p=0.000 n=20)
DecompressPayload/CompressorZLib-10    125.7µ ± 1%   123.7µ ± 0%   -1.60% (p=0.000 n=20)
DecompressPayload/CompressorZstd-10    70.13µ ± 1%   45.80µ ± 1%  -34.70% (p=0.000 n=20)
geomean                                235.3µ        124.0µ       -47.31%

                                    │ base.20.txt  │               new.20.txt               │
                                    │     B/s      │      B/s       vs base                 │
CompressPayload/CompressorZLib-10     365.2Mi ± 0%   3021.4Mi ± 1%  +727.41% (p=0.000 n=20)
CompressPayload/CompressorZstd-10     29.76Gi ± 1%    29.97Gi ± 0%    +0.73% (p=0.000 n=20)
DecompressPayload/CompressorZLib-10   15.28Gi ± 1%    15.53Gi ± 0%    +1.63% (p=0.000 n=20)
DecompressPayload/CompressorZstd-10   27.39Gi ± 1%    41.95Gi ± 1%   +53.13% (p=0.000 n=20)
geomean                               8.164Gi         15.49Gi        +89.77%

                                    │ base.20.txt  │               new.20.txt               │
                                    │     B/op     │     B/op      vs base                  │
CompressPayload/CompressorZLib-10     14.02Ki ± 0%   14.00Ki ± 0%   -0.10% (p=0.000 n=20)
CompressPayload/CompressorZstd-10     3.398Ki ± 0%   3.398Ki ± 0%        ~ (p=1.000 n=20) ¹
DecompressPayload/CompressorZLib-10   2.008Mi ± 0%   2.008Mi ± 0%        ~ (p=1.000 n=20) ¹
DecompressPayload/CompressorZstd-10   4.109Mi ± 0%   1.969Mi ± 0%  -52.08% (p=0.000 n=20)
geomean                               142.5Ki        118.5Ki       -16.82%
¹ all samples are equal

                                    │ base.20.txt  │              new.20.txt              │
                                    │  allocs/op   │ allocs/op   vs base                  │
CompressPayload/CompressorZLib-10       1.000 ± 0%   1.000 ± 0%        ~ (p=1.000 n=20) ¹
CompressPayload/CompressorZstd-10       4.000 ± 0%   4.000 ± 0%        ~ (p=1.000 n=20) ¹
DecompressPayload/CompressorZLib-10     26.00 ± 0%   26.00 ± 0%        ~ (p=1.000 n=20) ¹
DecompressPayload/CompressorZstd-10   104.000 ± 0%   1.000 ± 0%  -99.04% (p=0.000 n=20)
geomean                                 10.20        3.193       -68.69%
¹ all samples are equal
```
  • Loading branch information
charlievieth committed Jul 15, 2023
1 parent bc82b35 commit 60ed52e
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 44 deletions.
111 changes: 67 additions & 44 deletions x/mongo/driver/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,70 @@ type CompressionOpts struct {
UncompressedSize int32
}

var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
if err != nil {
panic(err)
}
return enc
}

var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
0: nil, // zstd.speedNotSet
zstd.SpeedFastest: zstdNewWriter(zstd.SpeedFastest),
zstd.SpeedDefault: zstdNewWriter(zstd.SpeedDefault),
zstd.SpeedBetterCompression: zstdNewWriter(zstd.SpeedBetterCompression),
zstd.SpeedBestCompression: zstdNewWriter(zstd.SpeedBestCompression),
}

func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
if v, ok := zstdEncoders.Load(level); ok {
return v.(*zstd.Encoder), nil
}
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
if err != nil {
return nil, err
if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
return zstdEncoders[level], nil
}
zstdEncoders.Store(level, encoder)
return encoder, nil
// The level is invalid so call zstd.NewWriter for the error.
return zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
}

var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
// zlibEncodersOffset is the offset into the zlibEncoders array for a given
// compression level.
const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2

var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool

func getZlibEncoder(level int) (*zlibEncoder, error) {
if v, ok := zlibEncoders.Load(level); ok {
return v.(*zlibEncoder), nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
return enc, nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
}
enc := &zlibEncoder{writer: writer, level: level}
return enc, nil
}
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
zlibEncoders.Store(level, encoder)
// The level is invalid so call zlib.NewWriterLever for the error.
_, err := zlib.NewWriterLevel(nil, level)
return nil, err
}

return encoder, nil
func putZlibEncoder(enc *zlibEncoder) {
if enc != nil {
zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
}
}

type zlibEncoder struct {
mu sync.Mutex
writer *zlib.Writer
buf *bytes.Buffer
buf bytes.Buffer
level int
}

func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
e.mu.Lock()
defer e.mu.Unlock()
defer putZlibEncoder(e)

e.buf.Reset()
e.writer.Reset(e.buf)
e.writer.Reset(&e.buf)

_, err := e.writer.Write(src)
if err != nil {
Expand Down Expand Up @@ -105,8 +127,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
}
}

var zstdReaderPool = sync.Pool{
New: func() interface{} {
r, _ := zstd.NewReader(nil)
return r
},
}

// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
Expand All @@ -117,34 +146,28 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er
} else if int32(l) != opts.UncompressedSize {
return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
}
uncompressed = make([]byte, opts.UncompressedSize)
return snappy.Decode(uncompressed, in)
out := make([]byte, opts.UncompressedSize)
return snappy.Decode(out, in)
case wiremessage.CompressorZLib:
r, err := zlib.NewReader(bytes.NewReader(in))
if err != nil {
return nil, err
}
defer func() {
err = r.Close()
}()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
out := make([]byte, opts.UncompressedSize)
if _, err := io.ReadFull(r, out); err != nil {
return nil, err
}
return uncompressed, nil
case wiremessage.CompressorZstd:
r, err := zstd.NewReader(bytes.NewBuffer(in))
if err != nil {
return nil, err
}
defer r.Close()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
if err := r.Close(); err != nil {
return nil, err
}
return uncompressed, nil
return out, nil
case wiremessage.CompressorZstd:
// Using a pool here is about ~20% faster
// than using a single global zstd.Reader
r := zstdReaderPool.Get().(*zstd.Decoder)
out, err := r.DecodeAll(in, nil)
zstdReaderPool.Put(r)
return out, err
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
Expand Down
40 changes: 40 additions & 0 deletions x/mongo/driver/compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,46 @@ func TestCompression(t *testing.T) {
}
}

func TestCompressionLevels(t *testing.T) {
errEq := func(e1, e2 error) bool {
if e1 == nil || e2 == nil {
return (e1 == nil) == (e2 == nil)
}
return e1.Error() == e2.Error()
}

in := []byte("abc")
wr := new(bytes.Buffer)

t.Run("ZLib", func(t *testing.T) {
opts := CompressionOpts{
Compressor: wiremessage.CompressorZLib,
}
for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ {
opts.ZlibLevel = lvl
_, err1 := CompressPayload(in, opts)
_, err2 := zlib.NewWriterLevel(wr, lvl)
if !errEq(err1, err2) {
t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2)
}
}
})

t.Run("Zstd", func(t *testing.T) {
opts := CompressionOpts{
Compressor: wiremessage.CompressorZstd,
}
for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ {
opts.ZstdLevel = int(lvl)
_, err1 := CompressPayload(in, opts)
_, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
if !errEq(err1, err2) {
t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2)
}
}
})
}

func TestDecompressFailures(t *testing.T) {
t.Parallel()

Expand Down
151 changes: 151 additions & 0 deletions x/mongo/driver/testdata/compression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package driver

import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"

"github.com/golang/snappy"
"github.com/klauspost/compress/zstd"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

// CompressionOpts holds settings for how to compress a payload
type CompressionOpts struct {
Compressor wiremessage.CompressorID
ZlibLevel int
ZstdLevel int
UncompressedSize int32
}

var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder

func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
if v, ok := zstdEncoders.Load(level); ok {
return v.(*zstd.Encoder), nil
}
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
if err != nil {
return nil, err
}
zstdEncoders.Store(level, encoder)
return encoder, nil
}

var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder

func getZlibEncoder(level int) (*zlibEncoder, error) {
if v, ok := zlibEncoders.Load(level); ok {
return v.(*zlibEncoder), nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
}
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
zlibEncoders.Store(level, encoder)

return encoder, nil
}

type zlibEncoder struct {
mu sync.Mutex
writer *zlib.Writer
buf *bytes.Buffer
}

func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
e.mu.Lock()
defer e.mu.Unlock()

e.buf.Reset()
e.writer.Reset(e.buf)

_, err := e.writer.Write(src)
if err != nil {
return nil, err
}
err = e.writer.Close()
if err != nil {
return nil, err
}
dst = append(dst[:0], e.buf.Bytes()...)
return dst, nil
}

// CompressPayload takes a byte slice and compresses it according to the options passed
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
case wiremessage.CompressorSnappy:
return snappy.Encode(nil, in), nil
case wiremessage.CompressorZLib:
encoder, err := getZlibEncoder(opts.ZlibLevel)
if err != nil {
return nil, err
}
return encoder.Encode(nil, in)
case wiremessage.CompressorZstd:
encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
if err != nil {
return nil, err
}
return encoder.EncodeAll(in, nil), nil
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
}

// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
case wiremessage.CompressorSnappy:
l, err := snappy.DecodedLen(in)
if err != nil {
return nil, fmt.Errorf("decoding compressed length %w", err)
} else if int32(l) != opts.UncompressedSize {
return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
}
uncompressed = make([]byte, opts.UncompressedSize)
return snappy.Decode(uncompressed, in)
case wiremessage.CompressorZLib:
r, err := zlib.NewReader(bytes.NewReader(in))
if err != nil {
return nil, err
}
defer func() {
err = r.Close()
}()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
return nil, err
}
return uncompressed, nil
case wiremessage.CompressorZstd:
r, err := zstd.NewReader(bytes.NewBuffer(in))
if err != nil {
return nil, err
}
defer r.Close()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
return nil, err
}
return uncompressed, nil
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
}

0 comments on commit 60ed52e

Please sign in to comment.