Skip to content

Commit

Permalink
Code review changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdale committed Sep 5, 2023
1 parent 9bbdb31 commit 51a6c7b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
22 changes: 12 additions & 10 deletions x/mongo/driver/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ type CompressionOpts struct {
UncompressedSize int32
}

func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
// destination writer. It panics on any errors and should only be used at
// package initialization time.
func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
if err != nil {
panic(err)
Expand All @@ -36,18 +39,18 @@ func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {

var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
0: nil, // zstd.speedNotSet
zstd.SpeedFastest: zstdNewWriter(zstd.SpeedFastest),
zstd.SpeedDefault: zstdNewWriter(zstd.SpeedDefault),
zstd.SpeedBetterCompression: zstdNewWriter(zstd.SpeedBetterCompression),
zstd.SpeedBestCompression: zstdNewWriter(zstd.SpeedBestCompression),
zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest),
zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault),
zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression),
}

func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
return zstdEncoders[level], nil
}
// The level is invalid so call zstd.NewWriter for the error.
return zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
// The level is outside the expected range, return an error.
return nil, fmt.Errorf("invalid zstd compression level: %d", level)
}

// zlibEncodersOffset is the offset into the zlibEncoders array for a given
Expand All @@ -68,9 +71,8 @@ func getZlibEncoder(level int) (*zlibEncoder, error) {
enc := &zlibEncoder{writer: writer, level: level}
return enc, nil
}
// The level is invalid so call zlib.NewWriterLever for the error.
_, err := zlib.NewWriterLevel(nil, level)
return nil, err
// The level is outside the expected range, return an error.
return nil, fmt.Errorf("invalid zlib compression level: %d", level)
}

func putZlibEncoder(enc *zlibEncoder) {
Expand Down
19 changes: 8 additions & 11 deletions x/mongo/driver/compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ func TestCompression(t *testing.T) {
}

func TestCompressionLevels(t *testing.T) {
errEq := func(e1, e2 error) bool {
if e1 == nil || e2 == nil {
return (e1 == nil) == (e2 == nil)
}
return e1.Error() == e2.Error()
}

in := []byte("abc")
wr := new(bytes.Buffer)

Expand All @@ -65,8 +58,10 @@ func TestCompressionLevels(t *testing.T) {
opts.ZlibLevel = lvl
_, err1 := CompressPayload(in, opts)
_, err2 := zlib.NewWriterLevel(wr, lvl)
if !errEq(err1, err2) {
t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2)
if err2 != nil {
assert.Error(t, err1, "expected an error for ZLib level %d", lvl)
} else {
assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl)
}
}
})
Expand All @@ -79,8 +74,10 @@ func TestCompressionLevels(t *testing.T) {
opts.ZstdLevel = int(lvl)
_, err1 := CompressPayload(in, opts)
_, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
if !errEq(err1, err2) {
t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2)
if err2 != nil {
assert.Error(t, err1, "expected an error for Zstd level %d", lvl)
} else {
assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl)
}
}
})
Expand Down

0 comments on commit 51a6c7b

Please sign in to comment.