From c299a83f950ed05c5af8177aaa6897b54b020d18 Mon Sep 17 00:00:00 2001 From: Benji Rewis Date: Mon, 22 Jul 2024 20:51:56 -0400 Subject: [PATCH] Revert "GODRIVER-2914 x/mongo/driver: enable parallel zlib compression and improve zstd decompression (#1320)" This reverts commit 84a43854bbd7c243ce5a68d7b56b3ba4ba10ed05. --- x/mongo/driver/compression.go | 114 +++++++------------ x/mongo/driver/compression_test.go | 137 ++-------------------- x/mongo/driver/testdata/compression.go | 151 ------------------------- 3 files changed, 53 insertions(+), 349 deletions(-) delete mode 100644 x/mongo/driver/testdata/compression.go diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index d79b024b74..7f355f61a4 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -26,72 +26,48 @@ type CompressionOpts struct { UncompressedSize int32 } -// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil -// destination writer. It panics on any errors and should only be used at -// package initialization time. -func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { - enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) - if err != nil { - panic(err) - } - return enc -} - -var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{ - 0: nil, // zstd.speedNotSet - zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest), - zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault), - zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression), - zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression), -} +var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { - if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression { - return zstdEncoders[level], nil + if v, ok := zstdEncoders.Load(level); ok { + return v.(*zstd.Encoder), nil + } + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + if err != nil { + return nil, err } - // The level is outside the expected range, return an error. - return nil, fmt.Errorf("invalid zstd compression level: %d", level) + zstdEncoders.Store(level, encoder) + return encoder, nil } -// zlibEncodersOffset is the offset into the zlibEncoders array for a given -// compression level. -const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2 - -var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool +var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder func getZlibEncoder(level int) (*zlibEncoder, error) { - if zlib.HuffmanOnly <= level && level <= zlib.BestCompression { - if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil { - return enc, nil - } - writer, err := zlib.NewWriterLevel(nil, level) - if err != nil { - return nil, err - } - enc := &zlibEncoder{writer: writer, level: level} - return enc, nil + if v, ok := zlibEncoders.Load(level); ok { + return v.(*zlibEncoder), nil } - // The level is outside the expected range, return an error. - return nil, fmt.Errorf("invalid zlib compression level: %d", level) -} - -func putZlibEncoder(enc *zlibEncoder) { - if enc != nil { - zlibEncoders[enc.level+zlibEncodersOffset].Put(enc) + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err } + encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} + zlibEncoders.Store(level, encoder) + + return encoder, nil } type zlibEncoder struct { + mu sync.Mutex writer *zlib.Writer - buf bytes.Buffer - level int + buf *bytes.Buffer } func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { - defer putZlibEncoder(e) + e.mu.Lock() + defer e.mu.Unlock() e.buf.Reset() - e.writer.Reset(&e.buf) + e.writer.Reset(e.buf) _, err := e.writer.Write(src) if err != nil { @@ -129,15 +105,8 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { } } -var zstdReaderPool = sync.Pool{ - New: func() interface{} { - r, _ := zstd.NewReader(nil) - return r - }, -} - // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed -func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { +func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { switch opts.Compressor { case wiremessage.CompressorNoOp: return in, nil @@ -148,29 +117,34 @@ func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { } else if int32(l) != opts.UncompressedSize { return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) } - out := make([]byte, opts.UncompressedSize) - return snappy.Decode(out, in) + uncompressed = make([]byte, opts.UncompressedSize) + return snappy.Decode(uncompressed, in) case wiremessage.CompressorZLib: r, err := zlib.NewReader(bytes.NewReader(in)) if err != nil { return nil, err } - out := make([]byte, opts.UncompressedSize) - if _, err := io.ReadFull(r, out); err != nil { + defer func() { + err = r.Close() + }() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { return nil, err } - if err := r.Close(); err != nil { + return uncompressed, nil + case wiremessage.CompressorZstd: + r, err := zstd.NewReader(bytes.NewBuffer(in)) + if err != nil { return nil, err } - return out, nil - case wiremessage.CompressorZstd: - buf := make([]byte, 0, opts.UncompressedSize) - // Using a pool here is about ~20% faster - // than using a single global zstd.Reader - r := zstdReaderPool.Get().(*zstd.Decoder) - out, err := r.DecodeAll(in, buf) - zstdReaderPool.Put(r) - return out, err + defer r.Close() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index 75a7ff072b..b477cb32c1 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -7,14 +7,9 @@ package driver import ( - "bytes" - "compress/zlib" "os" "testing" - "github.com/golang/snappy" - "github.com/klauspost/compress/zstd" - "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -46,43 +41,6 @@ func TestCompression(t *testing.T) { } } -func TestCompressionLevels(t *testing.T) { - in := []byte("abc") - wr := new(bytes.Buffer) - - t.Run("ZLib", func(t *testing.T) { - opts := CompressionOpts{ - Compressor: wiremessage.CompressorZLib, - } - for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ { - opts.ZlibLevel = lvl - _, err1 := CompressPayload(in, opts) - _, err2 := zlib.NewWriterLevel(wr, lvl) - if err2 != nil { - assert.Error(t, err1, "expected an error for ZLib level %d", lvl) - } else { - assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl) - } - } - }) - - t.Run("Zstd", func(t *testing.T) { - opts := CompressionOpts{ - Compressor: wiremessage.CompressorZstd, - } - for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ { - opts.ZstdLevel = int(lvl) - _, err1 := CompressPayload(in, opts) - _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel))) - if err2 != nil { - assert.Error(t, err1, "expected an error for Zstd level %d", lvl) - } else { - assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl) - } - } - }) -} - func TestDecompressFailures(t *testing.T) { t.Parallel() @@ -104,57 +62,18 @@ func TestDecompressFailures(t *testing.T) { }) } -var ( - compressionPayload []byte - compressedSnappyPayload []byte - compressedZLibPayload []byte - compressedZstdPayload []byte -) - -func initCompressionPayload(b *testing.B) { - if compressionPayload != nil { - return - } - data, err := os.ReadFile("testdata/compression.go") - if err != nil { - b.Fatal(err) - } - for i := 1; i < 10; i++ { - data = append(data, data...) - } - compressionPayload = data - - compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data) - - { - var buf bytes.Buffer - enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault)) +func BenchmarkCompressPayload(b *testing.B) { + payload := func() []byte { + buf, err := os.ReadFile("compression.go") if err != nil { - b.Fatal(err) + b.Log(err) + b.FailNow() } - compressedZstdPayload = enc.EncodeAll(data, nil) - } - - { - var buf bytes.Buffer - enc := zlib.NewWriter(&buf) - if _, err := enc.Write(data); err != nil { - b.Fatal(err) + for i := 1; i < 10; i++ { + buf = append(buf, buf...) } - if err := enc.Close(); err != nil { - b.Fatal(err) - } - if err := enc.Close(); err != nil { - b.Fatal(err) - } - compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...) - } - - b.ResetTimer() -} - -func BenchmarkCompressPayload(b *testing.B) { - initCompressionPayload(b) + return buf + }() compressors := []wiremessage.CompressorID{ wiremessage.CompressorSnappy, @@ -169,9 +88,6 @@ func BenchmarkCompressPayload(b *testing.B) { ZlibLevel: wiremessage.DefaultZlibLevel, ZstdLevel: wiremessage.DefaultZstdLevel, } - payload := compressionPayload - b.SetBytes(int64(len(payload))) - b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { _, err := CompressPayload(payload, opts) @@ -183,38 +99,3 @@ func BenchmarkCompressPayload(b *testing.B) { }) } } - -func BenchmarkDecompressPayload(b *testing.B) { - initCompressionPayload(b) - - benchmarks := []struct { - compressor wiremessage.CompressorID - payload []byte - }{ - {wiremessage.CompressorSnappy, compressedSnappyPayload}, - {wiremessage.CompressorZLib, compressedZLibPayload}, - {wiremessage.CompressorZstd, compressedZstdPayload}, - } - - for _, bench := range benchmarks { - b.Run(bench.compressor.String(), func(b *testing.B) { - opts := CompressionOpts{ - Compressor: bench.compressor, - ZlibLevel: wiremessage.DefaultZlibLevel, - ZstdLevel: wiremessage.DefaultZstdLevel, - UncompressedSize: int32(len(compressionPayload)), - } - payload := bench.payload - b.SetBytes(int64(len(compressionPayload))) - b.ReportAllocs() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := DecompressPayload(payload, opts) - if err != nil { - b.Fatal(err) - } - } - }) - }) - } -} diff --git a/x/mongo/driver/testdata/compression.go b/x/mongo/driver/testdata/compression.go deleted file mode 100644 index 7f355f61a4..0000000000 --- a/x/mongo/driver/testdata/compression.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package driver - -import ( - "bytes" - "compress/zlib" - "fmt" - "io" - "sync" - - "github.com/golang/snappy" - "github.com/klauspost/compress/zstd" - "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" -) - -// CompressionOpts holds settings for how to compress a payload -type CompressionOpts struct { - Compressor wiremessage.CompressorID - ZlibLevel int - ZstdLevel int - UncompressedSize int32 -} - -var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder - -func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { - if v, ok := zstdEncoders.Load(level); ok { - return v.(*zstd.Encoder), nil - } - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) - if err != nil { - return nil, err - } - zstdEncoders.Store(level, encoder) - return encoder, nil -} - -var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder - -func getZlibEncoder(level int) (*zlibEncoder, error) { - if v, ok := zlibEncoders.Load(level); ok { - return v.(*zlibEncoder), nil - } - writer, err := zlib.NewWriterLevel(nil, level) - if err != nil { - return nil, err - } - encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} - zlibEncoders.Store(level, encoder) - - return encoder, nil -} - -type zlibEncoder struct { - mu sync.Mutex - writer *zlib.Writer - buf *bytes.Buffer -} - -func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { - e.mu.Lock() - defer e.mu.Unlock() - - e.buf.Reset() - e.writer.Reset(e.buf) - - _, err := e.writer.Write(src) - if err != nil { - return nil, err - } - err = e.writer.Close() - if err != nil { - return nil, err - } - dst = append(dst[:0], e.buf.Bytes()...) - return dst, nil -} - -// CompressPayload takes a byte slice and compresses it according to the options passed -func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { - switch opts.Compressor { - case wiremessage.CompressorNoOp: - return in, nil - case wiremessage.CompressorSnappy: - return snappy.Encode(nil, in), nil - case wiremessage.CompressorZLib: - encoder, err := getZlibEncoder(opts.ZlibLevel) - if err != nil { - return nil, err - } - return encoder.Encode(nil, in) - case wiremessage.CompressorZstd: - encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) - if err != nil { - return nil, err - } - return encoder.EncodeAll(in, nil), nil - default: - return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) - } -} - -// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed -func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { - switch opts.Compressor { - case wiremessage.CompressorNoOp: - return in, nil - case wiremessage.CompressorSnappy: - l, err := snappy.DecodedLen(in) - if err != nil { - return nil, fmt.Errorf("decoding compressed length %w", err) - } else if int32(l) != opts.UncompressedSize { - return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) - } - uncompressed = make([]byte, opts.UncompressedSize) - return snappy.Decode(uncompressed, in) - case wiremessage.CompressorZLib: - r, err := zlib.NewReader(bytes.NewReader(in)) - if err != nil { - return nil, err - } - defer func() { - err = r.Close() - }() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { - return nil, err - } - return uncompressed, nil - case wiremessage.CompressorZstd: - r, err := zstd.NewReader(bytes.NewBuffer(in)) - if err != nil { - return nil, err - } - defer r.Close() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { - return nil, err - } - return uncompressed, nil - default: - return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) - } -}