From 60ed52e4ba4e5595f51db1bdd815abc9fb8de275 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Fri, 14 Jul 2023 21:36:19 -0400 Subject: [PATCH] x/mongo/driver: enable parallel zlib compression and improve zstd decompression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit fixes a bug where zlib compression was serialized across all goroutines. This occurred because only one shared zlib compressor was instantiated for each compression level, which had to be locked when used, thus preventing concurrent compression. This commit also slightly improves zstd decompression performance by using a pool of zstd decoders (instantiating a zstd encoder or decoder is fairly expensive). It also slightly cleans up the logic used to store and acquire zstd encoders. ``` goos: darwin goarch: arm64 pkg: go.mongodb.org/mongo-driver/x/mongo/driver │ base.20.txt │ new.20.txt │ │ sec/op │ sec/op vs base │ CompressPayload/CompressorZLib-10 5387.4µ ± 0% 651.1µ ± 1% -87.91% (p=0.000 n=20) CompressPayload/CompressorZstd-10 64.56µ ± 1% 64.10µ ± 0% -0.72% (p=0.000 n=20) DecompressPayload/CompressorZLib-10 125.7µ ± 1% 123.7µ ± 0% -1.60% (p=0.000 n=20) DecompressPayload/CompressorZstd-10 70.13µ ± 1% 45.80µ ± 1% -34.70% (p=0.000 n=20) geomean 235.3µ 124.0µ -47.31% │ base.20.txt │ new.20.txt │ │ B/s │ B/s vs base │ CompressPayload/CompressorZLib-10 365.2Mi ± 0% 3021.4Mi ± 1% +727.41% (p=0.000 n=20) CompressPayload/CompressorZstd-10 29.76Gi ± 1% 29.97Gi ± 0% +0.73% (p=0.000 n=20) DecompressPayload/CompressorZLib-10 15.28Gi ± 1% 15.53Gi ± 0% +1.63% (p=0.000 n=20) DecompressPayload/CompressorZstd-10 27.39Gi ± 1% 41.95Gi ± 1% +53.13% (p=0.000 n=20) geomean 8.164Gi 15.49Gi +89.77% │ base.20.txt │ new.20.txt │ │ B/op │ B/op vs base │ CompressPayload/CompressorZLib-10 14.02Ki ± 0% 14.00Ki ± 0% -0.10% (p=0.000 n=20) CompressPayload/CompressorZstd-10 3.398Ki ± 0% 3.398Ki ± 0% ~ (p=1.000 n=20) ¹ DecompressPayload/CompressorZLib-10 2.008Mi ± 0% 
2.008Mi ± 0% ~ (p=1.000 n=20) ¹ DecompressPayload/CompressorZstd-10 4.109Mi ± 0% 1.969Mi ± 0% -52.08% (p=0.000 n=20) geomean 142.5Ki 118.5Ki -16.82% ¹ all samples are equal │ base.20.txt │ new.20.txt │ │ allocs/op │ allocs/op vs base │ CompressPayload/CompressorZLib-10 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=20) ¹ CompressPayload/CompressorZstd-10 4.000 ± 0% 4.000 ± 0% ~ (p=1.000 n=20) ¹ DecompressPayload/CompressorZLib-10 26.00 ± 0% 26.00 ± 0% ~ (p=1.000 n=20) ¹ DecompressPayload/CompressorZstd-10 104.000 ± 0% 1.000 ± 0% -99.04% (p=0.000 n=20) geomean 10.20 3.193 -68.69% ¹ all samples are equal ``` --- x/mongo/driver/compression.go | 111 +++++++++++------- x/mongo/driver/compression_test.go | 40 +++++++ x/mongo/driver/testdata/compression.go | 151 +++++++++++++++++++++++++ 3 files changed, 258 insertions(+), 44 deletions(-) create mode 100644 x/mongo/driver/testdata/compression.go diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index 7f355f61a4..03a057df5e 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -26,48 +26,70 @@ type CompressionOpts struct { UncompressedSize int32 } -var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder +func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { + enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) + if err != nil { + panic(err) + } + return enc +} + +var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{ + 0: nil, // zstd.speedNotSet + zstd.SpeedFastest: zstdNewWriter(zstd.SpeedFastest), + zstd.SpeedDefault: zstdNewWriter(zstd.SpeedDefault), + zstd.SpeedBetterCompression: zstdNewWriter(zstd.SpeedBetterCompression), + zstd.SpeedBestCompression: zstdNewWriter(zstd.SpeedBestCompression), +} func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { - if v, ok := zstdEncoders.Load(level); ok { - return v.(*zstd.Encoder), nil - } - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) - if err != nil { - return nil, err + 
if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression { + return zstdEncoders[level], nil } - zstdEncoders.Store(level, encoder) - return encoder, nil + // The level is invalid so call zstd.NewWriter for the error. + return zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) } -var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder +// zlibEncodersOffset is the offset into the zlibEncoders array for a given +// compression level. +const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2 + +var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool func getZlibEncoder(level int) (*zlibEncoder, error) { - if v, ok := zlibEncoders.Load(level); ok { - return v.(*zlibEncoder), nil - } - writer, err := zlib.NewWriterLevel(nil, level) - if err != nil { - return nil, err + if zlib.HuffmanOnly <= level && level <= zlib.BestCompression { + if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil { + return enc, nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + enc := &zlibEncoder{writer: writer, level: level} + return enc, nil } - encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} - zlibEncoders.Store(level, encoder) + // The level is invalid so call zlib.NewWriterLevel for the error. 
+ _, err := zlib.NewWriterLevel(nil, level) + return nil, err +} - return encoder, nil +func putZlibEncoder(enc *zlibEncoder) { + if enc != nil { + zlibEncoders[enc.level+zlibEncodersOffset].Put(enc) + } } type zlibEncoder struct { - mu sync.Mutex writer *zlib.Writer - buf *bytes.Buffer + buf bytes.Buffer + level int } func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { - e.mu.Lock() - defer e.mu.Unlock() + defer putZlibEncoder(e) e.buf.Reset() - e.writer.Reset(e.buf) + e.writer.Reset(&e.buf) _, err := e.writer.Write(src) if err != nil { @@ -105,8 +127,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { } } +var zstdReaderPool = sync.Pool{ + New: func() interface{} { + r, _ := zstd.NewReader(nil) + return r + }, +} + // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed -func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { +func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { switch opts.Compressor { case wiremessage.CompressorNoOp: return in, nil @@ -117,34 +146,28 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er } else if int32(l) != opts.UncompressedSize { return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) } - uncompressed = make([]byte, opts.UncompressedSize) - return snappy.Decode(uncompressed, in) + out := make([]byte, opts.UncompressedSize) + return snappy.Decode(out, in) case wiremessage.CompressorZLib: r, err := zlib.NewReader(bytes.NewReader(in)) if err != nil { return nil, err } - defer func() { - err = r.Close() - }() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + out := make([]byte, opts.UncompressedSize) + if _, err := io.ReadFull(r, out); err != nil { return nil, err } - return uncompressed, nil - case wiremessage.CompressorZstd: - r, err := 
zstd.NewReader(bytes.NewBuffer(in)) - if err != nil { - return nil, err - } - defer r.Close() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + if err := r.Close(); err != nil { return nil, err } - return uncompressed, nil + return out, nil + case wiremessage.CompressorZstd: + // Using a pool here is about ~20% faster + // than using a single global zstd.Reader + r := zstdReaderPool.Get().(*zstd.Decoder) + out, err := r.DecodeAll(in, nil) + zstdReaderPool.Put(r) + return out, err default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index 5557257334..acca1317c6 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -46,6 +46,46 @@ func TestCompression(t *testing.T) { } } +func TestCompressionLevels(t *testing.T) { + errEq := func(e1, e2 error) bool { + if e1 == nil || e2 == nil { + return (e1 == nil) == (e2 == nil) + } + return e1.Error() == e2.Error() + } + + in := []byte("abc") + wr := new(bytes.Buffer) + + t.Run("ZLib", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZLib, + } + for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ { + opts.ZlibLevel = lvl + _, err1 := CompressPayload(in, opts) + _, err2 := zlib.NewWriterLevel(wr, lvl) + if !errEq(err1, err2) { + t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2) + } + } + }) + + t.Run("Zstd", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZstd, + } + for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ { + opts.ZstdLevel = int(lvl) + _, err1 := CompressPayload(in, opts) + _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel))) + if !errEq(err1, err2) { + t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2) + } + } + }) +} + func TestDecompressFailures(t 
*testing.T) { t.Parallel() diff --git a/x/mongo/driver/testdata/compression.go b/x/mongo/driver/testdata/compression.go new file mode 100644 index 0000000000..7f355f61a4 --- /dev/null +++ b/x/mongo/driver/testdata/compression.go @@ -0,0 +1,151 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driver + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" +) + +// CompressionOpts holds settings for how to compress a payload +type CompressionOpts struct { + Compressor wiremessage.CompressorID + ZlibLevel int + ZstdLevel int + UncompressedSize int32 +} + +var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder + +func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { + if v, ok := zstdEncoders.Load(level); ok { + return v.(*zstd.Encoder), nil + } + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + if err != nil { + return nil, err + } + zstdEncoders.Store(level, encoder) + return encoder, nil +} + +var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder + +func getZlibEncoder(level int) (*zlibEncoder, error) { + if v, ok := zlibEncoders.Load(level); ok { + return v.(*zlibEncoder), nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} + zlibEncoders.Store(level, encoder) + + return encoder, nil +} + +type zlibEncoder struct { + mu sync.Mutex + writer *zlib.Writer + buf *bytes.Buffer +} + +func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { + e.mu.Lock() + defer e.mu.Unlock() + + e.buf.Reset() + e.writer.Reset(e.buf) + + _, err := 
e.writer.Write(src) + if err != nil { + return nil, err + } + err = e.writer.Close() + if err != nil { + return nil, err + } + dst = append(dst[:0], e.buf.Bytes()...) + return dst, nil +} + +// CompressPayload takes a byte slice and compresses it according to the options passed +func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + return snappy.Encode(nil, in), nil + case wiremessage.CompressorZLib: + encoder, err := getZlibEncoder(opts.ZlibLevel) + if err != nil { + return nil, err + } + return encoder.Encode(nil, in) + case wiremessage.CompressorZstd: + encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) + if err != nil { + return nil, err + } + return encoder.EncodeAll(in, nil), nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} + +// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed +func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + l, err := snappy.DecodedLen(in) + if err != nil { + return nil, fmt.Errorf("decoding compressed length %w", err) + } else if int32(l) != opts.UncompressedSize { + return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) + } + uncompressed = make([]byte, opts.UncompressedSize) + return snappy.Decode(uncompressed, in) + case wiremessage.CompressorZLib: + r, err := zlib.NewReader(bytes.NewReader(in)) + if err != nil { + return nil, err + } + defer func() { + err = r.Close() + }() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + case wiremessage.CompressorZstd: + r, 
err := zstd.NewReader(bytes.NewBuffer(in)) + if err != nil { + return nil, err + } + defer r.Close() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +}