From c3e680ad8b07b736c89f5855ff9a1321d7ed2bd8 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 18 Sep 2024 11:30:34 +0200 Subject: [PATCH] zstd: Improve memory usage on small streaming encodes Very small streams will use EncodeAll internally when closing and no header has been written. This will pull a new encoder from the async buffer. Instead re-use the stream encoder. Before: ``` BenchmarkMem/flush-32 1359 837989 ns/op 7376959 B/op 109 allocs/op BenchmarkMem/no-flush-32 129 8884753 ns/op 112044489 B/op 254 allocs/op ``` After: ``` BenchmarkMem/flush-32 1254 922593 ns/op 7376966 B/op 109 allocs/op BenchmarkMem/no-flush-32 1488 841270 ns/op 7374164 B/op 29 allocs/op ``` Test is pretty much worst case, but shows the issue nicely. --- huff0/_generate/go.mod | 4 +- s2/cmd/_s2sx/go.mod | 4 +- zstd/_generate/go.mod | 4 +- zstd/blockdec.go | 4 +- zstd/encoder.go | 19 ++++--- zstd/framedec.go | 4 +- zstd/zstd_test.go | 112 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 138 insertions(+), 13 deletions(-) diff --git a/huff0/_generate/go.mod b/huff0/_generate/go.mod index aae4ebc6a7..8ece89a2f6 100644 --- a/huff0/_generate/go.mod +++ b/huff0/_generate/go.mod @@ -1,6 +1,8 @@ module github.com/klauspost/compress/s2/_generate -go 1.19 +go 1.21 + +toolchain go1.22.4 require ( github.com/klauspost/compress v1.15.15 diff --git a/s2/cmd/_s2sx/go.mod b/s2/cmd/_s2sx/go.mod index ca8b6456d6..a5e81bcffc 100644 --- a/s2/cmd/_s2sx/go.mod +++ b/s2/cmd/_s2sx/go.mod @@ -1,6 +1,8 @@ module github.com/klauspost/compress/s2/cmd/s2sx -go 1.19 +go 1.21 + +toolchain go1.22.4 require github.com/klauspost/compress v1.11.9 diff --git a/zstd/_generate/go.mod b/zstd/_generate/go.mod index aae4ebc6a7..8ece89a2f6 100644 --- a/zstd/_generate/go.mod +++ b/zstd/_generate/go.mod @@ -1,6 +1,8 @@ module github.com/klauspost/compress/s2/_generate -go 1.19 +go 1.21 + +toolchain go1.22.4 require ( github.com/klauspost/compress v1.15.15 diff --git a/zstd/blockdec.go b/zstd/blockdec.go index 03744fbc76..9c28840c3b 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -598,7 +598,9 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) { printf("RLE set to 0x%x, code: %v", symb, v) } case compModeFSE: - println("Reading table for", tableIndex(i)) + if debugDecoder { + println("Reading table for", tableIndex(i)) + } if seq.fse == nil || seq.fse.preDefined { seq.fse = fseDecoderPool.Get().(*fseDecoder) } diff --git a/zstd/encoder.go b/zstd/encoder.go index 72af7ef0fe..a79c4a527c 100644 --- a/zstd/encoder.go +++ b/zstd/encoder.go @@ -202,7 +202,7 @@ func (e *Encoder) nextBlock(final bool) error { return nil } if final && len(s.filling) > 0 { - s.current = e.EncodeAll(s.filling, s.current[:0]) + s.current = e.encodeAll(s.encoder, s.filling, s.current[:0]) var n2 int n2, s.err = s.w.Write(s.current) if s.err != nil { @@ -469,6 +469,15 @@ func (e *Encoder) Close() error { // Data compressed with EncodeAll can be decoded with the Decoder, // using either a stream or DecodeAll. func (e *Encoder) EncodeAll(src, dst []byte) []byte { + e.init.Do(e.initialize) + enc := <-e.encoders + defer func() { + e.encoders <- enc + }() + return e.encodeAll(enc, src, dst) +} + +func (e *Encoder) encodeAll(enc encoder, src, dst []byte) []byte { if len(src) == 0 { if e.o.fullZero { // Add frame header. @@ -491,13 +500,7 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte { } return dst } - e.init.Do(e.initialize) - enc := <-e.encoders - defer func() { - // Release encoder reference to last block. 
- // If a non-single block is needed the encoder will reset again. - e.encoders <- enc - }() + // Use single segments when above minimum window and below window size. single := len(src) <= e.o.windowSize && len(src) > MinWindowSize if e.o.single != nil { diff --git a/zstd/framedec.go b/zstd/framedec.go index 53e160f7e5..e47af66e7c 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -146,7 +146,9 @@ func (d *frameDec) reset(br byteBuffer) error { } return err } - printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3) + if debugDecoder { + printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3) + } windowLog := 10 + (wd >> 3) windowBase := uint64(1) << windowLog windowAdd := (windowBase / 8) * uint64(wd&0x7) diff --git a/zstd/zstd_test.go b/zstd/zstd_test.go index 82cb534643..22a0a4ad09 100644 --- a/zstd/zstd_test.go +++ b/zstd/zstd_test.go @@ -6,9 +6,12 @@ package zstd import ( "flag" "fmt" + "io" + "log" "os" "runtime" "runtime/pprof" + "strings" "testing" "time" ) @@ -59,3 +62,112 @@ func TestMatchLen(t *testing.T) { a[l] = ^a[l] } } + +func TestWriterMemUsage(t *testing.T) { + testMem := func(t *testing.T, fn func()) { + var before, after runtime.MemStats + var w io.Writer + if false { + f, err := os.Create(strings.ReplaceAll(fmt.Sprintf("%s.pprof", t.Name()), "/", "_")) + if err != nil { + log.Fatal(err) + } + defer f.Close() + w = f + t.Logf("opened memory profile %s", t.Name()) + } + runtime.GC() + runtime.ReadMemStats(&before) + fn() + runtime.GC() + runtime.ReadMemStats(&after) + if w != nil { + pprof.WriteHeapProfile(w) + } + t.Log("wrote profile") + t.Logf("%s: Memory Used: %dMB, %d allocs", t.Name(), (after.HeapInuse-before.HeapInuse)/1024/1024, after.HeapObjects-before.HeapObjects) + } + data := make([]byte, 10<<20) + + t.Run("enc-all-lower", func(t *testing.T) { + for level := SpeedFastest; level <= SpeedBestCompression; level++ { + t.Run(fmt.Sprint("level-", level), func(t *testing.T) { + var zr *Encoder + var err error + dst := make([]byte, 0, len(data)*2) + testMem(t, func() { + zr, err = NewWriter(io.Discard, WithEncoderConcurrency(32), WithEncoderLevel(level), WithLowerEncoderMem(false), WithWindowSize(1<<20)) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 100; i++ { + _ = zr.EncodeAll(data, dst[:0]) + } + }) + zr.Close() + }) + } + }) +} + +var data = []byte{1, 2, 3} + +func newZstdWriter() (*Encoder, error) { + return NewWriter( + io.Discard, + WithEncoderLevel(SpeedBetterCompression), + WithEncoderConcurrency(16), // we implicitly get this concurrency level if we run on 16 core CPU + WithLowerEncoderMem(false), + WithWindowSize(1<<20), + ) +} + +func BenchmarkMem(b *testing.B) { + b.Run("flush", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + w, err := newZstdWriter() + if err != nil { + b.Fatal(err) + } + + for j := 0; j < 16; j++ { + w.Reset(io.Discard) + + if _, err := w.Write(data); err != nil { + b.Fatal(err) + } + + if err := w.Flush(); err != nil { + b.Fatal(err) + } + + if err := w.Close(); err != nil { + b.Fatal(err) + } + } + } + }) + b.Run("no-flush", func(b *testing.B) { + // Will use encodeAll for block. + b.ReportAllocs() + for i := 0; i < b.N; i++ { + w, err := newZstdWriter() + if err != nil { + b.Fatal(err) + } + + for j := 0; j < 16; j++ { + w.Reset(io.Discard) + + if _, err := w.Write(data); err != nil { + b.Fatal(err) + } + + if err := w.Close(); err != nil { + b.Fatal(err) + } + } + } + }) +}
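
The pattern this change optimizes is the one exercised by the `no-flush` benchmark above: a tiny amount of data is written and the writer is closed without an explicit `Flush`, so `Close` takes the small-stream path that previously went through `EncodeAll` and pulled an extra encoder from the internal pool. As a rough, self-contained sketch of that usage (mirroring `newZstdWriter` and `BenchmarkMem/no-flush` from the patch; option values such as the 16-way concurrency and 1 MB window are illustrative, not required):

```go
// Sketch of the small-stream path: Write a few bytes, then Close
// without Flush. Before this patch each such Close routed through
// EncodeAll and grabbed an extra encoder from e.encoders; it now
// reuses the stream encoder instead.
package main

import (
	"io"
	"log"

	"github.com/klauspost/compress/zstd"
)

func main() {
	w, err := zstd.NewWriter(
		io.Discard,
		zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
		zstd.WithEncoderConcurrency(16), // matches the benchmark setup
		zstd.WithLowerEncoderMem(false),
		zstd.WithWindowSize(1<<20),
	)
	if err != nil {
		log.Fatal(err)
	}

	payload := []byte{1, 2, 3}
	for i := 0; i < 16; i++ {
		w.Reset(io.Discard)
		if _, err := w.Write(payload); err != nil {
			log.Fatal(err)
		}
		// Closing without a prior Flush is the case this patch improves:
		// no frame header has been written yet, so the whole stream is
		// encoded as a single small frame on Close.
		if err := w.Close(); err != nil {
			log.Fatal(err)
		}
	}
}
```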