diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index 6f0db87fc8..d79b024b74 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -26,7 +26,10 @@ type CompressionOpts struct { UncompressedSize int32 } -func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { +// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil +// destination writer. It panics on any errors and should only be used at +// package initialization time. +func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) if err != nil { panic(err) @@ -36,18 +39,18 @@ func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{ 0: nil, // zstd.speedNotSet - zstd.SpeedFastest: zstdNewWriter(zstd.SpeedFastest), - zstd.SpeedDefault: zstdNewWriter(zstd.SpeedDefault), - zstd.SpeedBetterCompression: zstdNewWriter(zstd.SpeedBetterCompression), - zstd.SpeedBestCompression: zstdNewWriter(zstd.SpeedBestCompression), + zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest), + zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault), + zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression), + zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression), } func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression { return zstdEncoders[level], nil } - // The level is invalid so call zstd.NewWriter for the error. - return zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + // The level is outside the expected range, return an error. + return nil, fmt.Errorf("invalid zstd compression level: %d", level) } // zlibEncodersOffset is the offset into the zlibEncoders array for a given @@ -68,9 +71,8 @@ func getZlibEncoder(level int) (*zlibEncoder, error) { enc := &zlibEncoder{writer: writer, level: level} return enc, nil } - // The level is invalid so call zlib.NewWriterLever for the error. - _, err := zlib.NewWriterLevel(nil, level) - return nil, err + // The level is outside the expected range, return an error. + return nil, fmt.Errorf("invalid zlib compression level: %d", level) } func putZlibEncoder(enc *zlibEncoder) { diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index acca1317c6..75a7ff072b 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -47,13 +47,6 @@ func TestCompression(t *testing.T) { } func TestCompressionLevels(t *testing.T) { - errEq := func(e1, e2 error) bool { - if e1 == nil || e2 == nil { - return (e1 == nil) == (e2 == nil) - } - return e1.Error() == e2.Error() - } - in := []byte("abc") wr := new(bytes.Buffer) @@ -65,8 +58,10 @@ func TestCompressionLevels(t *testing.T) { opts.ZlibLevel = lvl _, err1 := CompressPayload(in, opts) _, err2 := zlib.NewWriterLevel(wr, lvl) - if !errEq(err1, err2) { - t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2) + if err2 != nil { + assert.Error(t, err1, "expected an error for ZLib level %d", lvl) + } else { + assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl) } } }) @@ -79,8 +74,10 @@ func TestCompressionLevels(t *testing.T) { opts.ZstdLevel = int(lvl) _, err1 := CompressPayload(in, opts) _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel))) - if !errEq(err1, err2) { - t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2) + if err2 != nil { + assert.Error(t, err1, "expected an error for Zstd level %d", lvl) + } else { + assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl) } } })