diff --git a/s2/README.md b/s2/README.md index 1d80c42a53..8284bb0810 100644 --- a/s2/README.md +++ b/s2/README.md @@ -20,11 +20,12 @@ This is important, so you don't have to worry about spending CPU cycles on alrea * Concurrent stream compression * Faster decompression, even for Snappy compatible content * Concurrent Snappy/S2 stream decompression -* Ability to quickly skip forward in compressed stream +* Skip forward in compressed stream * Random seeking with indexes * Compatible with reading Snappy compressed content * Smaller block size overhead on incompressible blocks * Block concatenation +* Block Dictionary support * Uncompressed stream mode * Automatic stream size padding * Snappy compatible block compression @@ -594,6 +595,123 @@ Best... 10737418240 -> 4210602774 [39.21%]; 42.96s, 254.4MB/s Decompression speed should be around the same as using the 'better' compression mode. +## Dictionaries + +*Note: S2 dictionary compression is currently at an early implementation stage, with no assembly for +neither encoding nor decoding. Performance improvements can be expected in the future.* + +Adding dictionaries allow providing a custom dictionary that will serve as lookup in the beginning of blocks. + +The same dictionary *must* be used for both encoding and decoding. +S2 does not keep track of whether the same dictionary is used, +and using the wrong dictionary will most often not result in an error when decompressing. + +Blocks encoded *without* dictionaries can be decompressed seamlessly *with* a dictionary. +This means it is possible to switch from an encoding without dictionaries to an encoding with dictionaries +and treat the blocks similarly. + +Similar to [zStandard dictionaries](https://github.com/facebook/zstd#the-case-for-small-data-compression), +the same usage scenario applies to S2 dictionaries. + +> Training works if there is some correlation in a family of small data samples. The more data-specific a dictionary is, the more efficient it is (there is no universal dictionary). Hence, deploying one dictionary per type of data will provide the greatest benefits. Dictionary gains are mostly effective in the first few KB. Then, the compression algorithm will gradually use previously decoded content to better compress the rest of the file. + +S2 further limits the dictionary to only be enabled on the first 64KB of a block. +This will remove any negative (speed) impacts of the dictionaries on bigger blocks. + +### Compression + +Using the [github_users_sample_set](https://github.com/facebook/zstd/releases/download/v1.1.3/github_users_sample_set.tar.zst) +and a 64KB dictionary trained with zStandard the following sizes can be achieved. + +| | Default | Better | Best | +|--------------------|------------------|------------------|-----------------------| +| Without Dictionary | 3362023 (44.92%) | 3083163 (41.19%) | 3057944 (40.86%) | +| With Dictionary | 921524 (12.31%) | 873154 (11.67%) | 785503 bytes (10.49%) | + +So for highly repetitive content, this case provides an almost 3x reduction in size. + +For less uniform data we will use the Go source code tree. +Compressing First 64KB of all `.go` files in `go/src`, Go 1.19.5, 8912 files, 51253563 bytes input: + +| | Default | Better | Best | +|--------------------|-------------------|-------------------|-------------------| +| Without Dictionary | 22955767 (44.79%) | 20189613 (39.39% | 19482828 (38.01%) | +| With Dictionary | 19654568 (38.35%) | 16289357 (31.78%) | 15184589 (29.63%) | +| Saving/file | 362 bytes | 428 bytes | 472 bytes | + + +### Creating Dictionaries + +There are no tools to create dictionaries in S2. +However, there are multiple ways to create a useful dictionary: + +#### Using a Sample File + +If your input is very uniform, you can just use a sample file as the dictionary. + +For example in the `github_users_sample_set` above, the average compression only goes up from +10.49% to 11.48% by using the first file as dictionary compared to using a dedicated dictionary. + +```Go + // Read a sample + sample, err := os.ReadFile("sample.json") + + // Create a dictionary. + dict := s2.MakeDict(sample, nil) + + // b := dict.Bytes() will provide a dictionary that can be saved + // and reloaded with s2.NewDict(b). + + // To encode: + encoded := dict.Encode(nil, file) + + // To decode: + decoded, err := dict.Decode(nil, file) +``` + +#### Using Zstandard + +Zstandard dictionaries can easily be converted to S2 dictionaries. + +This can be helpful to generate dictionaries for files that don't have a fixed structure. + + +Example, with training set files placed in `./training-set`: + +`λ zstd -r --train-fastcover training-set/* --maxdict=65536 -o name.dict` + +This will create a dictionary of 64KB, that can be converted to a dictionary like this: + +```Go + // Decode the Zstandard dictionary. + insp, err := zstd.InspectDictionary(zdict) + if err != nil { + panic(err) + } + + // We are only interested in the contents. + // Assume that files start with "// Copyright (c) 2023". + // Search for the longest match for that. + // This may save a few bytes. + dict := s2.MakeDict(insp.Content(), []byte("// Copyright (c) 2023")) + + // b := dict.Bytes() will provide a dictionary that can be saved + // and reloaded with s2.NewDict(b). + + // We can now encode using this dictionary + encodedWithDict := dict.Encode(nil, payload) + + // To decode content: + decoded, err := dict.Decode(nil, encodedWithDict) +``` + +It is recommended to save the dictionary returned by ` b:= dict.Bytes()`, since that will contain only the S2 dictionary. + +This dictionary can later be loaded using `s2.NewDict(b)`. The dictionary then no longer requires `zstd` to be initialized. + +Also note how `s2.MakeDict` allows you to search for a common starting sequence of your files. +This can be omitted, at the expense of a few bytes. + # Snappy Compatibility S2 now offers full compatibility with Snappy. @@ -929,6 +1047,72 @@ The first copy of a block cannot be a repeat offset and the offset is reset on e Default streaming block size is 1MB. +# Dictionary Encoding + +Adding dictionaries allow providing a custom dictionary that will serve as lookup in the beginning of blocks. + +A dictionary provides an initial repeat value that can be used to point to a common header. + +Other than that the dictionary contains values that can be used as back-references. + +Often used data should be placed at the *end* of the dictionary since offsets < 2048 bytes will be smaller. + +## Format + +Dictionary *content* must at least 16 bytes and less or equal to 64KiB (65536 bytes). + +Encoding: `[repeat value (uvarint)][dictionary content...]` + +Before the dictionary content, an unsigned base-128 (uvarint) encoded value specifying the initial repeat offset. +This value is an offset into the dictionary content and not a back-reference offset, +so setting this to 0 will make the repeat value point to the first value of the dictionary. + +The value must be less than the dictionary length-8 + +## Encoding + +From the decoder point of view the dictionary content is seen as preceding the encoded content. + +`[dictionary content][decoded output]` + +Backreferences to the dictionary are encoded as ordinary backreferences that have an offset before the start of the decoded block. + +Matches copying from the dictionary are **not** allowed to cross from the dictionary into the decoded data. +However, if a copy ends at the end of the dictionary the next repeat will point to the start of the decoded buffer, which is allowed. + +The first match can be a repeat value, which will use the repeat offset stored in the dictionary. + +When 64KB (65536 bytes) has been en/decoded it is no longer allowed to reference the dictionary, +neither by a copy nor repeat operations. +If the boundary is crossed while copying from the dictionary, the operation should complete, +but the next instruction is not allowed to reference the dictionary. + +Valid blocks encoded *without* a dictionary can be decoded with any dictionary. +There are no checks whether the supplied dictionary is the correct for a block. +Because of this there is no overhead by using a dictionary. + +## Example + +This is the dictionary content. Elements are separated by `[]`. + +Dictionary: `[0x0a][Yesterday 25 bananas were added to Benjamins brown bag]`. + +Initial repeat offset is set at 10, which is the letter `2`. + +Encoded `[LIT "10"][REPEAT len=10][LIT "hich"][MATCH off=50 len=6][MATCH off=31 len=6][MATCH off=61 len=10]` + +Decoded: `[10][ bananas w][hich][ were ][brown ][were added]` + +Output: `10 bananas which were brown were added` + + +## Streams + +For streams each block can use the dictionary. + +The dictionary cannot not currently be provided on the stream. + + # LICENSE This code is based on the [Snappy-Go](https://github.com/golang/snappy) implementation. diff --git a/s2/_generate/gen.go b/s2/_generate/gen.go index ec0d83a8ad..c5be5443de 100644 --- a/s2/_generate/gen.go +++ b/s2/_generate/gen.go @@ -1067,14 +1067,12 @@ func (o options) genEncodeBetterBlockAsm(name string, lTableBits, sTableBits, sk SUBL(repeatL, rep) // rep = s - repeat // if uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) { - left, right := GP64(), GP64() - MOVQ(Mem{Base: src, Index: rep, Disp: 0, Scale: 1}, right.As64()) - MOVQ(cv, left) tmp := GP64() - MOVQ(U64(repeatMask), tmp) - ANDQ(tmp, left) - ANDQ(tmp, right) - CMPQ(left.As64(), right.As64()) + mask := GP64() + MOVQ(Mem{Base: src, Index: rep, Disp: 0, Scale: 1}, tmp.As64()) + MOVQ(U64(repeatMask), mask) + XORQ(cv.As64(), tmp.As64()) + TESTQ(mask.As64(), tmp.As64()) // BAIL, no repeat. JNE(LabelRef("no_repeat_found_" + name)) } diff --git a/s2/decode.go b/s2/decode.go index a89db8f96a..b7c9adfdd8 100644 --- a/s2/decode.go +++ b/s2/decode.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "math" "runtime" + "strconv" "sync" ) @@ -1111,3 +1112,370 @@ func (r *Reader) SkippableCB(id uint8, fn func(r io.Reader) error) error { r.skippableCB[id] = fn return nil } + +// s2DecodeDict writes the decoding of src to dst. It assumes that the varint-encoded +// length of the decompressed bytes has already been read, and that len(dst) +// equals that length. +// +// It returns 0 on success or a decodeErrCodeXxx error code on failure. +func s2DecodeDict(dst, src []byte, dict *Dict) int { + if dict == nil { + return s2Decode(dst, src) + } + const debug = false + const debugErrs = debug + + if debug { + fmt.Println("Starting decode, dst len:", len(dst)) + } + var d, s, length int + offset := len(dict.dict) - dict.repeat + + // As long as we can read at least 5 bytes... + for s < len(src)-5 { + // Removing bounds checks is SLOWER, when if doing + // in := src[s:s+5] + // Checked on Go 1.18 + switch src[s] & 0x03 { + case tagLiteral: + x := uint32(src[s] >> 2) + switch { + case x < 60: + s++ + case x == 60: + s += 2 + x = uint32(src[s-1]) + case x == 61: + in := src[s : s+3] + x = uint32(in[1]) | uint32(in[2])<<8 + s += 3 + case x == 62: + in := src[s : s+4] + // Load as 32 bit and shift down. + x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24 + x >>= 8 + s += 4 + case x == 63: + in := src[s : s+5] + x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24 + s += 5 + } + length = int(x) + 1 + if debug { + fmt.Println("literals, length:", length, "d-after:", d+length) + } + if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) { + if debugErrs { + fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s) + } + return decodeErrCodeCorrupt + } + + copy(dst[d:], src[s:s+length]) + d += length + s += length + continue + + case tagCopy1: + s += 2 + toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])) + length = int(src[s-2]) >> 2 & 0x7 + if toffset == 0 { + if debug { + fmt.Print("(repeat) ") + } + // keep last offset + switch length { + case 5: + length = int(src[s]) + 4 + s += 1 + case 6: + in := src[s : s+2] + length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8) + s += 2 + case 7: + in := src[s : s+3] + length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16) + s += 3 + default: // 0-> 4 + } + } else { + offset = toffset + } + length += 4 + case tagCopy2: + in := src[s : s+3] + offset = int(uint32(in[1]) | uint32(in[2])<<8) + length = 1 + int(in[0])>>2 + s += 3 + + case tagCopy4: + in := src[s : s+5] + offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24) + length = 1 + int(in[0])>>2 + s += 5 + } + + if offset <= 0 || length > len(dst)-d { + if debugErrs { + fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d) + } + return decodeErrCodeCorrupt + } + + // copy from dict + if d < offset { + if d > MaxDictSrcOffset { + if debugErrs { + fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length) + } + return decodeErrCodeCorrupt + } + startOff := len(dict.dict) - offset + d + if startOff < 0 || startOff+length > len(dict.dict) { + if debugErrs { + fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict)) + } + return decodeErrCodeCorrupt + } + if debug { + fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff) + } + copy(dst[d:d+length], dict.dict[startOff:]) + d += length + continue + } + + if debug { + fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length) + } + + // Copy from an earlier sub-slice of dst to a later sub-slice. + // If no overlap, use the built-in copy: + if offset > length { + copy(dst[d:d+length], dst[d-offset:]) + d += length + continue + } + + // Unlike the built-in copy function, this byte-by-byte copy always runs + // forwards, even if the slices overlap. Conceptually, this is: + // + // d += forwardCopy(dst[d:d+length], dst[d-offset:]) + // + // We align the slices into a and b and show the compiler they are the same size. + // This allows the loop to run without bounds checks. + a := dst[d : d+length] + b := dst[d-offset:] + b = b[:len(a)] + for i := range a { + a[i] = b[i] + } + d += length + } + + // Remaining with extra checks... + for s < len(src) { + switch src[s] & 0x03 { + case tagLiteral: + x := uint32(src[s] >> 2) + switch { + case x < 60: + s++ + case x == 60: + s += 2 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + x = uint32(src[s-1]) + case x == 61: + s += 3 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + x = uint32(src[s-2]) | uint32(src[s-1])<<8 + case x == 62: + s += 4 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16 + case x == 63: + s += 5 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 + } + length = int(x) + 1 + if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) { + if debugErrs { + fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s) + } + return decodeErrCodeCorrupt + } + if debug { + fmt.Println("literals, length:", length, "d-after:", d+length) + } + + copy(dst[d:], src[s:s+length]) + d += length + s += length + continue + + case tagCopy1: + s += 2 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + length = int(src[s-2]) >> 2 & 0x7 + toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])) + if toffset == 0 { + if debug { + fmt.Print("(repeat) ") + } + // keep last offset + switch length { + case 5: + s += 1 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + length = int(uint32(src[s-1])) + 4 + case 6: + s += 2 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8) + case 7: + s += 3 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16) + default: // 0-> 4 + } + } else { + offset = toffset + } + length += 4 + case tagCopy2: + s += 3 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + length = 1 + int(src[s-3])>>2 + offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8) + + case tagCopy4: + s += 5 + if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. + if debugErrs { + fmt.Println("src went oob") + } + return decodeErrCodeCorrupt + } + length = 1 + int(src[s-5])>>2 + offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24) + } + + if offset <= 0 || length > len(dst)-d { + if debugErrs { + fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d) + } + return decodeErrCodeCorrupt + } + + // copy from dict + if d < offset { + if d > MaxDictSrcOffset { + if debugErrs { + fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length) + } + return decodeErrCodeCorrupt + } + rOff := len(dict.dict) - (offset - d) + if debug { + fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff) + } + if rOff+length > len(dict.dict) { + if debugErrs { + fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length) + } + return decodeErrCodeCorrupt + } + if rOff < 0 { + if debugErrs { + fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length) + } + return decodeErrCodeCorrupt + } + copy(dst[d:d+length], dict.dict[rOff:]) + d += length + continue + } + + if debug { + fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length) + } + + // Copy from an earlier sub-slice of dst to a later sub-slice. + // If no overlap, use the built-in copy: + if offset > length { + copy(dst[d:d+length], dst[d-offset:]) + d += length + continue + } + + // Unlike the built-in copy function, this byte-by-byte copy always runs + // forwards, even if the slices overlap. Conceptually, this is: + // + // d += forwardCopy(dst[d:d+length], dst[d-offset:]) + // + // We align the slices into a and b and show the compiler they are the same size. + // This allows the loop to run without bounds checks. + a := dst[d : d+length] + b := dst[d-offset:] + b = b[:len(a)] + for i := range a { + a[i] = b[i] + } + d += length + } + + if d != len(dst) { + if debugErrs { + fmt.Println("wanted length", len(dst), "got", d) + } + return decodeErrCodeCorrupt + } + return 0 +} diff --git a/s2/dict.go b/s2/dict.go new file mode 100644 index 0000000000..24f7ce80bc --- /dev/null +++ b/s2/dict.go @@ -0,0 +1,331 @@ +// Copyright (c) 2022+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package s2 + +import ( + "bytes" + "encoding/binary" + "sync" +) + +const ( + // MinDictSize is the minimum dictionary size when repeat has been read. + MinDictSize = 16 + + // MaxDictSize is the maximum dictionary size when repeat has been read. + MaxDictSize = 65536 + + // MaxDictSrcOffset is the maximum offset where a dictionary entry can start. + MaxDictSrcOffset = 65535 +) + +// Dict contains a dictionary that can be used for encoding and decoding s2 +type Dict struct { + dict []byte + repeat int // Repeat as index of dict + + fast, better, best sync.Once + fastTable *[1 << 14]uint16 + + betterTableShort *[1 << 14]uint16 + betterTableLong *[1 << 17]uint16 + + bestTableShort *[1 << 16]uint32 + bestTableLong *[1 << 19]uint32 +} + +// NewDict will read a dictionary. +// It will return nil if the dictionary is invalid. +func NewDict(dict []byte) *Dict { + if len(dict) == 0 { + return nil + } + var d Dict + // Repeat is the first value of the dict + r, n := binary.Uvarint(dict) + if n <= 0 { + return nil + } + dict = dict[n:] + d.dict = dict + if cap(d.dict) < len(d.dict)+16 { + d.dict = append(make([]byte, 0, len(d.dict)+16), d.dict...) + } + if len(dict) < MinDictSize || len(dict) > MaxDictSize { + return nil + } + d.repeat = int(r) + if d.repeat > len(dict) { + return nil + } + return &d +} + +// Bytes will return a serialized version of the dictionary. +// The output can be sent to NewDict. +func (d *Dict) Bytes() []byte { + dst := make([]byte, binary.MaxVarintLen16+len(d.dict)) + return append(dst[:binary.PutUvarint(dst, uint64(d.repeat))], d.dict...) +} + +// MakeDict will create a dictionary. +// 'data' must be at least MinDictSize. +// If data is longer than MaxDictSize only the last MaxDictSize bytes will be used. +// If searchStart is set the start repeat value will be set to the last +// match of this content. +// If no matches are found, it will attempt to find shorter matches. +// This content should match the typical start of a block. +// If at least 4 bytes cannot be matched, repeat is set to start of block. +func MakeDict(data []byte, searchStart []byte) *Dict { + if len(data) == 0 { + return nil + } + if len(data) > MaxDictSize { + data = data[len(data)-MaxDictSize:] + } + var d Dict + dict := data + d.dict = dict + if cap(d.dict) < len(d.dict)+16 { + d.dict = append(make([]byte, 0, len(d.dict)+16), d.dict...) + } + if len(dict) < MinDictSize { + return nil + } + + // Find the longest match possible, last entry if multiple. + for s := len(searchStart); s > 4; s-- { + if idx := bytes.LastIndex(data, searchStart[:s]); idx >= 0 && idx <= len(data)-8 { + d.repeat = idx + break + } + } + + return &d +} + +// Encode returns the encoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire encoded block. +// Otherwise, a newly allocated slice will be returned. +// +// The dst and src must not overlap. It is valid to pass a nil dst. +// +// The blocks will require the same amount of memory to decode as encoding, +// and does not make for concurrent decoding. +// Also note that blocks do not contain CRC information, so corruption may be undetected. +// +// If you need to encode larger amounts of data, consider using +// the streaming interface which gives all of these features. +func (d *Dict) Encode(dst, src []byte) []byte { + if n := MaxEncodedLen(len(src)); n < 0 { + panic(ErrTooLarge) + } else if cap(dst) < n { + dst = make([]byte, n) + } else { + dst = dst[:n] + } + + // The block starts with the varint-encoded length of the decompressed bytes. + dstP := binary.PutUvarint(dst, uint64(len(src))) + + if len(src) == 0 { + return dst[:dstP] + } + if len(src) < minNonLiteralBlockSize { + dstP += emitLiteral(dst[dstP:], src) + return dst[:dstP] + } + n := encodeBlockDictGo(dst[dstP:], src, d) + if n > 0 { + dstP += n + return dst[:dstP] + } + // Not compressible + dstP += emitLiteral(dst[dstP:], src) + return dst[:dstP] +} + +// EncodeBetter returns the encoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire encoded block. +// Otherwise, a newly allocated slice will be returned. +// +// EncodeBetter compresses better than Encode but typically with a +// 10-40% speed decrease on both compression and decompression. +// +// The dst and src must not overlap. It is valid to pass a nil dst. +// +// The blocks will require the same amount of memory to decode as encoding, +// and does not make for concurrent decoding. +// Also note that blocks do not contain CRC information, so corruption may be undetected. +// +// If you need to encode larger amounts of data, consider using +// the streaming interface which gives all of these features. +func (d *Dict) EncodeBetter(dst, src []byte) []byte { + if n := MaxEncodedLen(len(src)); n < 0 { + panic(ErrTooLarge) + } else if len(dst) < n { + dst = make([]byte, n) + } + + // The block starts with the varint-encoded length of the decompressed bytes. + dstP := binary.PutUvarint(dst, uint64(len(src))) + + if len(src) == 0 { + return dst[:dstP] + } + if len(src) < minNonLiteralBlockSize { + dstP += emitLiteral(dst[dstP:], src) + return dst[:dstP] + } + n := encodeBlockBetterDict(dst[dstP:], src, d) + if n > 0 { + dstP += n + return dst[:dstP] + } + // Not compressible + dstP += emitLiteral(dst[dstP:], src) + return dst[:dstP] +} + +// EncodeBest returns the encoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire encoded block. +// Otherwise, a newly allocated slice will be returned. +// +// EncodeBest compresses as good as reasonably possible but with a +// big speed decrease. +// +// The dst and src must not overlap. It is valid to pass a nil dst. +// +// The blocks will require the same amount of memory to decode as encoding, +// and does not make for concurrent decoding. +// Also note that blocks do not contain CRC information, so corruption may be undetected. +// +// If you need to encode larger amounts of data, consider using +// the streaming interface which gives all of these features. +func (d *Dict) EncodeBest(dst, src []byte) []byte { + if n := MaxEncodedLen(len(src)); n < 0 { + panic(ErrTooLarge) + } else if len(dst) < n { + dst = make([]byte, n) + } + + // The block starts with the varint-encoded length of the decompressed bytes. + dstP := binary.PutUvarint(dst, uint64(len(src))) + + if len(src) == 0 { + return dst[:dstP] + } + if len(src) < minNonLiteralBlockSize { + dstP += emitLiteral(dst[dstP:], src) + return dst[:dstP] + } + n := encodeBlockBest(dst[dstP:], src, d) + if n > 0 { + dstP += n + return dst[:dstP] + } + // Not compressible + dstP += emitLiteral(dst[dstP:], src) + return dst[:dstP] +} + +// Decode returns the decoded form of src. The returned slice may be a sub- +// slice of dst if dst was large enough to hold the entire decoded block. +// Otherwise, a newly allocated slice will be returned. +// +// The dst and src must not overlap. It is valid to pass a nil dst. +func (d *Dict) Decode(dst, src []byte) ([]byte, error) { + dLen, s, err := decodedLen(src) + if err != nil { + return nil, err + } + if dLen <= cap(dst) { + dst = dst[:dLen] + } else { + dst = make([]byte, dLen) + } + if s2DecodeDict(dst, src[s:], d) != 0 { + return nil, ErrCorrupt + } + return dst, nil +} + +func (d *Dict) initFast() { + d.fast.Do(func() { + const ( + tableBits = 14 + maxTableSize = 1 << tableBits + ) + + var table [maxTableSize]uint16 + // We stop so any entry of length 8 can always be read. + for i := 0; i < len(d.dict)-8-2; i += 3 { + x0 := load64(d.dict, i) + h0 := hash6(x0, tableBits) + h1 := hash6(x0>>8, tableBits) + h2 := hash6(x0>>16, tableBits) + table[h0] = uint16(i) + table[h1] = uint16(i + 1) + table[h2] = uint16(i + 2) + } + d.fastTable = &table + }) +} + +func (d *Dict) initBetter() { + d.better.Do(func() { + const ( + // Long hash matches. + lTableBits = 17 + maxLTableSize = 1 << lTableBits + + // Short hash matches. + sTableBits = 14 + maxSTableSize = 1 << sTableBits + ) + + var lTable [maxLTableSize]uint16 + var sTable [maxSTableSize]uint16 + + // We stop so any entry of length 8 can always be read. + for i := 0; i < len(d.dict)-8; i++ { + cv := load64(d.dict, i) + lTable[hash7(cv, lTableBits)] = uint16(i) + sTable[hash4(cv, sTableBits)] = uint16(i) + } + d.betterTableShort = &sTable + d.betterTableLong = &lTable + }) +} + +func (d *Dict) initBest() { + d.best.Do(func() { + const ( + // Long hash matches. + lTableBits = 19 + maxLTableSize = 1 << lTableBits + + // Short hash matches. + sTableBits = 16 + maxSTableSize = 1 << sTableBits + ) + + var lTable [maxLTableSize]uint32 + var sTable [maxSTableSize]uint32 + + // We stop so any entry of length 8 can always be read. + for i := 0; i < len(d.dict)-8; i++ { + cv := load64(d.dict, i) + hashL := hash8(cv, lTableBits) + hashS := hash4(cv, sTableBits) + candidateL := lTable[hashL] + candidateS := sTable[hashS] + lTable[hashL] = uint32(i) | candidateL<<16 + sTable[hashS] = uint32(i) | candidateS<<16 + } + d.bestTableShort = &sTable + d.bestTableLong = &lTable + }) +} diff --git a/s2/dict_test.go b/s2/dict_test.go new file mode 100644 index 0000000000..c0bab2d87e --- /dev/null +++ b/s2/dict_test.go @@ -0,0 +1,493 @@ +// Copyright (c) 2023+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package s2 + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "fmt" + "io" + "math/rand" + "os" + "testing" + + "github.com/klauspost/compress/internal/fuzz" + "github.com/klauspost/compress/zstd" +) + +func TestDict(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + data := make([]byte, 128<<10) + for i := range data { + data[i] = uint8(rng.Intn(256)) + } + + // Should match the first 64K + d := NewDict(append([]byte{0}, data[:65536]...)) + encoded := make([]byte, MaxEncodedLen(len(data))) + res := encodeBlockDictGo(encoded, data, d) + if res == 0 || res > len(data)-65500 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded := make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + + // Add dict that will produce a full match 5000 chars into the input. + d = NewDict(append([]byte{0}, data[5000:65536+5000]...)) + encoded = make([]byte, MaxEncodedLen(len(data))) + res = encodeBlockDictGo(encoded, data, d) + if res == 0 || res > len(data)-65500 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded = make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + + // generate copies + for i := 1; i < len(data); { + n := rng.Intn(32) + 4 + off := rng.Intn(len(data) - n) + copy(data[i:], data[off:off+n]) + i += n + } + + dict := make([]byte, 65536) + for i := 1; i < len(dict); { + n := rng.Intn(32) + 4 + off := rng.Intn(65536 - n) + copy(dict[i:], data[off:off+n]) + i += n + } + d = NewDict(dict) + encoded = make([]byte, MaxEncodedLen(len(data))) + res = encodeBlockDictGo(encoded, data, d) + if res == 0 || res > len(data)-20000 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded = make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + os.WriteFile("decoded.bin", decoded, os.ModePerm) + os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } +} + +func TestDictBetter(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + data := make([]byte, 128<<10) + for i := range data { + data[i] = uint8(rng.Intn(256)) + } + + // Should match the first 64K + d := NewDict(append([]byte{0}, data[:65536]...)) + encoded := make([]byte, MaxEncodedLen(len(data))) + res := encodeBlockBetterDict(encoded, data, d) + if res == 0 || res > len(data)-65500 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded := make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + + // Add dict that will produce a full match 5000 chars into the input. + d = NewDict(append([]byte{0}, data[5000:65536+5000]...)) + encoded = make([]byte, MaxEncodedLen(len(data))) + res = encodeBlockBetterDict(encoded, data, d) + if res == 0 || res > len(data)-65500 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded = make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + + // generate copies + for i := 1; i < len(data); { + n := rng.Intn(32) + 4 + off := rng.Intn(len(data) - n) + copy(data[i:], data[off:off+n]) + i += n + } + + dict := make([]byte, 65536) + for i := 1; i < len(dict); { + n := rng.Intn(32) + 4 + off := rng.Intn(65536 - n) + copy(dict[i:], data[off:off+n]) + i += n + } + d = NewDict(dict) + encoded = make([]byte, MaxEncodedLen(len(data))) + res = encodeBlockBetterDict(encoded, data, d) + if res == 0 || res > len(data)-20000 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded = make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + os.WriteFile("decoded.bin", decoded, os.ModePerm) + os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } +} + +func TestDictBest(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + data := make([]byte, 128<<10) + for i := range data { + data[i] = uint8(rng.Intn(256)) + } + + // Should match the first 64K + d := NewDict(append([]byte{0}, data[:65536]...)) + encoded := make([]byte, MaxEncodedLen(len(data))) + res := encodeBlockBest(encoded, data, d) + if res == 0 || res > len(data)-65500 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded := make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + + // Add dict that will produce a full match 5000 chars into the input. + d = NewDict(append([]byte{0}, data[5000:65536+5000]...)) + encoded = make([]byte, MaxEncodedLen(len(data))) + res = encodeBlockBest(encoded, data, d) + if res == 0 || res > len(data)-65500 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded = make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + + // generate copies + for i := 1; i < len(data); { + n := rng.Intn(32) + 4 + off := rng.Intn(len(data) - n) + copy(data[i:], data[off:off+n]) + i += n + } + + dict := make([]byte, 65536) + for i := 1; i < len(dict); { + n := rng.Intn(32) + 4 + off := rng.Intn(65536 - n) + copy(dict[i:], data[off:off+n]) + i += n + } + d = NewDict(dict) + encoded = make([]byte, MaxEncodedLen(len(data))) + res = encodeBlockBest(encoded, data, d) + if res == 0 || res > len(data)-20000 { + t.Errorf("did no get expected dict saving. Saved %d bytes", len(data)-res) + } + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + decoded = make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + os.WriteFile("decoded.bin", decoded, os.ModePerm) + os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } +} + +func TestDictBetter2(t *testing.T) { + // Should match the first 64K + data := []byte("10 bananas which were brown were added") + d := NewDict(append([]byte{6}, []byte("Yesterday 25 bananas were added to Benjamins brown bag")...)) + encoded := make([]byte, MaxEncodedLen(len(data))) + res := encodeBlockBetterDict(encoded, data, d) + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + t.Log(string(encoded)) + decoded := make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } +} + +func TestDictBest2(t *testing.T) { + // Should match the first 64K + data := []byte("10 bananas which were brown were added") + d := NewDict(append([]byte{6}, []byte("Yesterday 25 bananas were added to Benjamins brown bag")...)) + encoded := make([]byte, MaxEncodedLen(len(data))) + res := encodeBlockBest(encoded, data, d) + encoded = encoded[:res] + t.Log("saved", len(data)-res, "bytes") + t.Log(string(encoded)) + decoded := make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + //os.WriteFile("decoded.bin", decoded, os.ModePerm) + //os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } +} + +func TestDictSize(t *testing.T) { + //f, err := os.Open("testdata/xlmeta.tar.s2") + //f, err := os.Open("testdata/broken.tar.s2") + f, err := os.Open("testdata/github_users_sample_set.tar.s2") + //f, err := os.Open("testdata/gofiles2.tar.s2") + //f, err := os.Open("testdata/gosrc.tar.s2") + if err != nil { + t.Skip(err) + } + stream := NewReader(f) + in := tar.NewReader(stream) + //rawDict, err := os.ReadFile("testdata/godict.dictator") + rawDict, err := os.ReadFile("testdata/gofiles.dict") + //rawDict, err := os.ReadFile("testdata/gosrc2.dict") + //rawDict, err := os.ReadFile("testdata/td.dict") + //rawDict, err := os.ReadFile("testdata/users.dict") + //rawDict, err := os.ReadFile("testdata/xlmeta.dict") + if err != nil { + t.Fatal(err) + } + + lidx := -1 + if di, err := zstd.InspectDictionary(rawDict); err == nil { + rawDict = di.Content() + lidx = len(rawDict) - di.Offsets()[0] + } else { + t.Errorf("Loading dict: %v", err) + return + } + + searchFor := "" + if false { + searchFor = "// Copyright 2022" + } + d := MakeDict(rawDict, []byte(searchFor)) + if d == nil { + t.Fatal("no dict", lidx) + } + + var totalIn int + var totalOut int + var totalCount int + for { + h, err := in.Next() + if err != nil { + break + } + if h.Size == 0 { + continue + } + data := make([]byte, 65536) + t.Run(h.Name, func(t *testing.T) { + if int(h.Size) < 65536 { + data = data[:h.Size] + } else { + data = data[:65536] + } + _, err := io.ReadFull(in, data) + if err != nil { + t.Skip() + } + if d == nil { + // Use first file as dict + d = MakeDict(data, nil) + } + // encode + encoded := make([]byte, MaxEncodedLen(len(data))) + totalIn += len(data) + totalCount++ + //res := encodeBlockBest(encoded, data, nil) + res := encodeBlockBest(encoded, data, d) + //res := encodeBlockBetterDict(encoded, data, d) + //res := encodeBlockBetterGo(encoded, data) + //res := encodeBlockDictGo(encoded, data, d) + // res := encodeBlockGo(encoded, data) + if res == 0 { + totalOut += len(data) + return + } + totalOut += res + encoded = encoded[:res] + //t.Log("encoded", len(data), "->", res, "saved", len(data)-res, "bytes") + decoded := make([]byte, len(data)) + res = s2DecodeDict(decoded, encoded, d) + if res != 0 { + t.Fatalf("got result: %d", res) + } + if !bytes.Equal(decoded, data) { + os.WriteFile("decoded.bin", decoded, os.ModePerm) + os.WriteFile("original.bin", data, os.ModePerm) + t.Fatal("decoded mismatch") + } + }) + } + fmt.Printf("%d files, %d -> %d (%.2f%%) - %.02f bytes saved/file\n", totalCount, totalIn, totalOut, float64(totalOut*100)/float64(totalIn), float64(totalIn-totalOut)/float64(totalCount)) +} + +func FuzzDictBlocks(f *testing.F) { + fuzz.AddFromZip(f, "testdata/enc_regressions.zip", true, false) + fuzz.AddFromZip(f, "testdata/fuzz/block-corpus-raw.zip", true, testing.Short()) + fuzz.AddFromZip(f, "testdata/fuzz/block-corpus-enc.zip", false, testing.Short()) + + // Fuzzing tweaks: + const ( + // Max input size: + maxSize = 8 << 20 + ) + file, err := os.Open("testdata/s2-dict.bin.gz") + if err != nil { + f.Fatal(err) + } + gzr, err := gzip.NewReader(file) + if err != nil { + f.Fatal(err) + } + dictBytes, err := io.ReadAll(gzr) + if err != nil { + f.Fatal(err) + } + dict := NewDict(dictBytes) + if dict == nil { + f.Fatal("invalid dict") + } + + f.Fuzz(func(t *testing.T, data []byte) { + if len(data) > maxSize { + return + } + + writeDst := make([]byte, MaxEncodedLen(len(data)), MaxEncodedLen(len(data))+4) + writeDst = append(writeDst, 1, 2, 3, 4) + defer func() { + got := writeDst[MaxEncodedLen(len(data)):] + want := []byte{1, 2, 3, 4} + if !bytes.Equal(got, want) { + t.Fatalf("want %v, got %v - dest modified outside cap", want, got) + } + }() + compDst := writeDst[:MaxEncodedLen(len(data)):MaxEncodedLen(len(data))] // Hard cap + decDst := make([]byte, len(data)) + comp := dict.Encode(compDst, data) + decoded, err := dict.Decode(decDst, comp) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(data, decoded) { + t.Error("block decoder mismatch") + return + } + if mel := MaxEncodedLen(len(data)); len(comp) > mel { + t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp))) + return + } + comp = dict.EncodeBetter(compDst, data) + decoded, err = dict.Decode(decDst, comp) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(data, decoded) { + t.Error("block decoder mismatch") + return + } + if mel := MaxEncodedLen(len(data)); len(comp) > mel { + t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp))) + return + } + + comp = dict.EncodeBest(compDst, data) + decoded, err = dict.Decode(decDst, comp) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(data, decoded) { + t.Error("block decoder mismatch") + return + } + if mel := MaxEncodedLen(len(data)); len(comp) > mel { + t.Error(fmt.Errorf("MaxEncodedLen Exceed: input: %d, mel: %d, got %d", len(data), mel, len(comp))) + return + } + }) +} diff --git a/s2/encode.go b/s2/encode.go index 59edf1a576..c2ca7236a1 100644 --- a/s2/encode.go +++ b/s2/encode.go @@ -158,7 +158,7 @@ func EncodeBest(dst, src []byte) []byte { d += emitLiteral(dst[d:], src) return dst[:d] } - n := encodeBlockBest(dst[d:], src) + n := encodeBlockBest(dst[d:], src, nil) if n > 0 { d += n return dst[:d] @@ -820,7 +820,7 @@ func (w *Writer) encodeBlock(obuf, uncompressed []byte) int { case levelBetter: return encodeBlockBetter(obuf, uncompressed) case levelBest: - return encodeBlockBest(obuf, uncompressed) + return encodeBlockBest(obuf, uncompressed, nil) } return 0 } diff --git a/s2/encode_all.go b/s2/encode_all.go index 54c71d3b5d..11657f0949 100644 --- a/s2/encode_all.go +++ b/s2/encode_all.go @@ -8,6 +8,7 @@ package s2 import ( "bytes" "encoding/binary" + "fmt" "math/bits" ) @@ -455,3 +456,594 @@ emitRemainder: } return d } + +// encodeBlockGo encodes a non-empty src to a guaranteed-large-enough dst. It +// assumes that the varint-encoded length of the decompressed bytes has already +// been written. +// +// It also assumes that: +// +// len(dst) >= MaxEncodedLen(len(src)) && +// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize +func encodeBlockDictGo(dst, src []byte, dict *Dict) (d int) { + // Initialize the hash table. + const ( + tableBits = 14 + maxTableSize = 1 << tableBits + maxAhead = 8 // maximum bytes ahead without checking sLimit + + debug = false + ) + dict.initFast() + + var table [maxTableSize]uint32 + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := len(src) - inputMargin + if sLimit > MaxDictSrcOffset-maxAhead { + sLimit = MaxDictSrcOffset - maxAhead + } + + // Bail if we can't compress to at least this. + dstLimit := len(src) - len(src)>>5 - 5 + + // nextEmit is where in src the next emitLiteral should start from. + nextEmit := 0 + + // The encoded form can start with a dict entry (copy or repeat). + s := 0 + + // Convert dict repeat to offset + repeat := len(dict.dict) - dict.repeat + cv := load64(src, 0) + + // While in dict +searchDict: + for { + // Next src position to check + nextS := s + (s-nextEmit)>>6 + 4 + hash0 := hash6(cv, tableBits) + hash1 := hash6(cv>>8, tableBits) + if nextS > sLimit { + if debug { + fmt.Println("slimit reached", s, nextS) + } + break searchDict + } + candidateDict := int(dict.fastTable[hash0]) + candidateDict2 := int(dict.fastTable[hash1]) + candidate2 := int(table[hash1]) + candidate := int(table[hash0]) + table[hash0] = uint32(s) + table[hash1] = uint32(s + 1) + hash2 := hash6(cv>>16, tableBits) + + // Check repeat at offset checkRep. + const checkRep = 1 + + if repeat > s { + candidate := len(dict.dict) - repeat + s + if repeat-s >= 4 && uint32(cv) == load32(dict.dict, candidate) { + // Extend back + base := s + for i := candidate; base > nextEmit && i > 0 && dict.dict[i-1] == src[base-1]; { + i-- + base-- + } + d += emitLiteral(dst[d:], src[nextEmit:base]) + if debug && nextEmit != base { + fmt.Println("emitted ", base-nextEmit, "literals") + } + s += 4 + candidate += 4 + for candidate < len(dict.dict)-8 && s <= len(src)-8 { + if diff := load64(src, s) ^ load64(dict.dict, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + d += emitRepeat(dst[d:], repeat, s-base) + if debug { + fmt.Println("emitted dict repeat length", s-base, "offset:", repeat, "s:", s) + } + nextEmit = s + if s >= sLimit { + break searchDict + } + cv = load64(src, s) + continue + } + } else if uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) { + base := s + checkRep + // Extend back + for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; { + i-- + base-- + } + d += emitLiteral(dst[d:], src[nextEmit:base]) + if debug && nextEmit != base { + fmt.Println("emitted ", base-nextEmit, "literals") + } + + // Extend forward + candidate := s - repeat + 4 + checkRep + s += 4 + checkRep + for s <= sLimit { + if diff := load64(src, s) ^ load64(src, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + if debug { + // Validate match. + if s <= candidate { + panic("s <= candidate") + } + a := src[base:s] + b := src[base-repeat : base-repeat+(s-base)] + if !bytes.Equal(a, b) { + panic("mismatch") + } + } + + if nextEmit > 0 { + // same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset. + d += emitRepeat(dst[d:], repeat, s-base) + } else { + // First match, cannot be repeat. + d += emitCopy(dst[d:], repeat, s-base) + } + + nextEmit = s + if s >= sLimit { + break searchDict + } + if debug { + fmt.Println("emitted reg repeat", s-base, "s:", s) + } + cv = load64(src, s) + continue searchDict + } + if s == 0 { + cv = load64(src, nextS) + s = nextS + continue searchDict + } + // Start with table. These matches will always be closer. + if uint32(cv) == load32(src, candidate) { + goto emitMatch + } + candidate = int(table[hash2]) + if uint32(cv>>8) == load32(src, candidate2) { + table[hash2] = uint32(s + 2) + candidate = candidate2 + s++ + goto emitMatch + } + + // Check dict. Dicts have longer offsets, so we want longer matches. + if cv == load64(dict.dict, candidateDict) { + table[hash2] = uint32(s + 2) + goto emitDict + } + + candidateDict = int(dict.fastTable[hash2]) + // Check if upper 7 bytes match + if candidateDict2 >= 1 { + if cv^load64(dict.dict, candidateDict2-1) < (1 << 8) { + table[hash2] = uint32(s + 2) + candidateDict = candidateDict2 + s++ + goto emitDict + } + } + + table[hash2] = uint32(s + 2) + if uint32(cv>>16) == load32(src, candidate) { + s += 2 + goto emitMatch + } + if candidateDict >= 2 { + // Check if upper 6 bytes match + if cv^load64(dict.dict, candidateDict-2) < (1 << 16) { + s += 2 + goto emitDict + } + } + + cv = load64(src, nextS) + s = nextS + continue searchDict + + emitDict: + { + if debug { + if load32(dict.dict, candidateDict) != load32(src, s) { + panic("dict emit mismatch") + } + } + // Extend backwards. + // The top bytes will be rechecked to get the full match. + for candidateDict > 0 && s > nextEmit && dict.dict[candidateDict-1] == src[s-1] { + candidateDict-- + s-- + } + + // Bail if we exceed the maximum size. + if d+(s-nextEmit) > dstLimit { + return 0 + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + d += emitLiteral(dst[d:], src[nextEmit:s]) + if debug && nextEmit != s { + fmt.Println("emitted ", s-nextEmit, "literals") + } + { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + base := s + repeat = s + (len(dict.dict)) - candidateDict + + // Extend the 4-byte match as long as possible. + s += 4 + candidateDict += 4 + for s <= len(src)-8 && len(dict.dict)-candidateDict >= 8 { + if diff := load64(src, s) ^ load64(dict.dict, candidateDict); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidateDict += 8 + } + + // Matches longer than 64 are split. + if s <= sLimit || s-base < 8 { + d += emitCopy(dst[d:], repeat, s-base) + } else { + // Split to ensure we don't start a copy within next block + d += emitCopy(dst[d:], repeat, 4) + d += emitRepeat(dst[d:], repeat, s-base-4) + } + if false { + // Validate match. + if s <= candidate { + panic("s <= candidate") + } + a := src[base:s] + b := dict.dict[base-repeat : base-repeat+(s-base)] + if !bytes.Equal(a, b) { + panic("mismatch") + } + } + if debug { + fmt.Println("emitted dict copy, length", s-base, "offset:", repeat, "s:", s) + } + nextEmit = s + if s >= sLimit { + break searchDict + } + + if d > dstLimit { + // Do we have space for more, if not bail. + return 0 + } + + // Index and continue loop to try new candidate. + x := load64(src, s-2) + m2Hash := hash6(x, tableBits) + currHash := hash6(x>>8, tableBits) + candidate = int(table[currHash]) + table[m2Hash] = uint32(s - 2) + table[currHash] = uint32(s - 1) + cv = load64(src, s) + } + continue + } + emitMatch: + + // Extend backwards. + // The top bytes will be rechecked to get the full match. + for candidate > 0 && s > nextEmit && src[candidate-1] == src[s-1] { + candidate-- + s-- + } + + // Bail if we exceed the maximum size. + if d+(s-nextEmit) > dstLimit { + return 0 + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + d += emitLiteral(dst[d:], src[nextEmit:s]) + if debug && nextEmit != s { + fmt.Println("emitted ", s-nextEmit, "literals") + } + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + base := s + repeat = base - candidate + + // Extend the 4-byte match as long as possible. + s += 4 + candidate += 4 + for s <= len(src)-8 { + if diff := load64(src, s) ^ load64(src, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + + d += emitCopy(dst[d:], repeat, s-base) + if debug { + // Validate match. + if s <= candidate { + panic("s <= candidate") + } + a := src[base:s] + b := src[base-repeat : base-repeat+(s-base)] + if !bytes.Equal(a, b) { + panic("mismatch") + } + } + if debug { + fmt.Println("emitted src copy, length", s-base, "offset:", repeat, "s:", s) + } + nextEmit = s + if s >= sLimit { + break searchDict + } + + if d > dstLimit { + // Do we have space for more, if not bail. + return 0 + } + // Check for an immediate match, otherwise start search at s+1 + x := load64(src, s-2) + m2Hash := hash6(x, tableBits) + currHash := hash6(x>>16, tableBits) + candidate = int(table[currHash]) + table[m2Hash] = uint32(s - 2) + table[currHash] = uint32(s) + if debug && s == candidate { + panic("s == candidate") + } + if uint32(x>>16) != load32(src, candidate) { + cv = load64(src, s+1) + s++ + break + } + } + } + + // Search without dict: + if repeat > s { + repeat = 0 + } + + // No more dict + sLimit = len(src) - inputMargin + if s >= sLimit { + goto emitRemainder + } + if debug { + fmt.Println("non-dict matching at", s, "repeat:", repeat) + } + cv = load64(src, s) + if debug { + fmt.Println("now", s, "->", sLimit, "out:", d, "left:", len(src)-s, "nextemit:", nextEmit, "dstLimit:", dstLimit, "s:", s) + } + for { + candidate := 0 + for { + // Next src position to check + nextS := s + (s-nextEmit)>>6 + 4 + if nextS > sLimit { + goto emitRemainder + } + hash0 := hash6(cv, tableBits) + hash1 := hash6(cv>>8, tableBits) + candidate = int(table[hash0]) + candidate2 := int(table[hash1]) + table[hash0] = uint32(s) + table[hash1] = uint32(s + 1) + hash2 := hash6(cv>>16, tableBits) + + // Check repeat at offset checkRep. + const checkRep = 1 + if repeat > 0 && uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) { + base := s + checkRep + // Extend back + for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; { + i-- + base-- + } + d += emitLiteral(dst[d:], src[nextEmit:base]) + if debug && nextEmit != base { + fmt.Println("emitted ", base-nextEmit, "literals") + } + // Extend forward + candidate := s - repeat + 4 + checkRep + s += 4 + checkRep + for s <= sLimit { + if diff := load64(src, s) ^ load64(src, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + if debug { + // Validate match. + if s <= candidate { + panic("s <= candidate") + } + a := src[base:s] + b := src[base-repeat : base-repeat+(s-base)] + if !bytes.Equal(a, b) { + panic("mismatch") + } + } + if nextEmit > 0 { + // same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset. + d += emitRepeat(dst[d:], repeat, s-base) + } else { + // First match, cannot be repeat. + d += emitCopy(dst[d:], repeat, s-base) + } + if debug { + fmt.Println("emitted src repeat length", s-base, "offset:", repeat, "s:", s) + } + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + + cv = load64(src, s) + continue + } + + if uint32(cv) == load32(src, candidate) { + break + } + candidate = int(table[hash2]) + if uint32(cv>>8) == load32(src, candidate2) { + table[hash2] = uint32(s + 2) + candidate = candidate2 + s++ + break + } + table[hash2] = uint32(s + 2) + if uint32(cv>>16) == load32(src, candidate) { + s += 2 + break + } + + cv = load64(src, nextS) + s = nextS + } + + // Extend backwards. + // The top bytes will be rechecked to get the full match. + for candidate > 0 && s > nextEmit && src[candidate-1] == src[s-1] { + candidate-- + s-- + } + + // Bail if we exceed the maximum size. + if d+(s-nextEmit) > dstLimit { + return 0 + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + d += emitLiteral(dst[d:], src[nextEmit:s]) + if debug && nextEmit != s { + fmt.Println("emitted ", s-nextEmit, "literals") + } + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + base := s + repeat = base - candidate + + // Extend the 4-byte match as long as possible. + s += 4 + candidate += 4 + for s <= len(src)-8 { + if diff := load64(src, s) ^ load64(src, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + + d += emitCopy(dst[d:], repeat, s-base) + if debug { + // Validate match. + if s <= candidate { + panic("s <= candidate") + } + a := src[base:s] + b := src[base-repeat : base-repeat+(s-base)] + if !bytes.Equal(a, b) { + panic("mismatch") + } + } + if debug { + fmt.Println("emitted src copy, length", s-base, "offset:", repeat, "s:", s) + } + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + + if d > dstLimit { + // Do we have space for more, if not bail. + return 0 + } + // Check for an immediate match, otherwise start search at s+1 + x := load64(src, s-2) + m2Hash := hash6(x, tableBits) + currHash := hash6(x>>16, tableBits) + candidate = int(table[currHash]) + table[m2Hash] = uint32(s - 2) + table[currHash] = uint32(s) + if debug && s == candidate { + panic("s == candidate") + } + if uint32(x>>16) != load32(src, candidate) { + cv = load64(src, s+1) + s++ + break + } + } + } + +emitRemainder: + if nextEmit < len(src) { + // Bail if we exceed the maximum size. + if d+len(src)-nextEmit > dstLimit { + return 0 + } + d += emitLiteral(dst[d:], src[nextEmit:]) + if debug && nextEmit != s { + fmt.Println("emitted ", len(src)-nextEmit, "literals") + } + } + return d +} diff --git a/s2/encode_best.go b/s2/encode_best.go index 1b7ea394fa..1d13e869a1 100644 --- a/s2/encode_best.go +++ b/s2/encode_best.go @@ -7,6 +7,7 @@ package s2 import ( "fmt" + "math" "math/bits" ) @@ -18,7 +19,7 @@ import ( // // len(dst) >= MaxEncodedLen(len(src)) && // minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize -func encodeBlockBest(dst, src []byte) (d int) { +func encodeBlockBest(dst, src []byte, dict *Dict) (d int) { // Initialize the hash tables. const ( // Long hash matches. @@ -30,6 +31,8 @@ func encodeBlockBest(dst, src []byte) (d int) { maxSTableSize = 1 << sTableBits inputMargin = 8 + 2 + + debug = false ) // sLimit is when to stop looking for offset/length copies. The inputMargin @@ -39,6 +42,10 @@ func encodeBlockBest(dst, src []byte) (d int) { if len(src) < minNonLiteralBlockSize { return 0 } + sLimitDict := len(src) - inputMargin + if sLimitDict > MaxDictSrcOffset-inputMargin { + sLimitDict = MaxDictSrcOffset - inputMargin + } var lTable [maxLTableSize]uint64 var sTable [maxSTableSize]uint64 @@ -52,10 +59,15 @@ func encodeBlockBest(dst, src []byte) (d int) { // The encoded form must start with a literal, as there are no previous // bytes to copy, so we start looking for hash matches at s == 1. s := 1 + repeat := 1 + if dict != nil { + dict.initBest() + s = 0 + repeat = len(dict.dict) - dict.repeat + } cv := load64(src, s) // We search for a repeat at -1, but don't output repeats when nextEmit == 0 - repeat := 1 const lowbitMask = 0xffffffff getCur := func(x uint64) int { return int(x & lowbitMask) @@ -67,11 +79,11 @@ func encodeBlockBest(dst, src []byte) (d int) { for { type match struct { - offset int - s int - length int - score int - rep bool + offset int + s int + length int + score int + rep, dict bool } var best match for { @@ -85,6 +97,12 @@ func encodeBlockBest(dst, src []byte) (d int) { if nextS > sLimit { goto emitRemainder } + if dict != nil && s >= MaxDictSrcOffset { + dict = nil + if repeat > s { + repeat = math.MinInt32 + } + } hashL := hash8(cv, lTableBits) hashS := hash4(cv, sTableBits) candidateL := lTable[hashL] @@ -114,7 +132,15 @@ func encodeBlockBest(dst, src []byte) (d int) { } m := match{offset: offset, s: s, length: 4 + offset, rep: rep} s += 4 - for s <= sLimit { + for s < len(src) { + if len(src)-s < 8 { + if src[s] == src[m.length] { + m.length++ + s++ + continue + } + break + } if diff := load64(src, s) ^ load64(src, m.length); diff != 0 { m.length += bits.TrailingZeros64(diff) >> 3 break @@ -130,6 +156,62 @@ func encodeBlockBest(dst, src []byte) (d int) { } return m } + matchDict := func(candidate, s int, first uint32, rep bool) match { + // Calculate offset as if in continuous array with s + offset := -len(dict.dict) + candidate + if best.length != 0 && best.s-best.offset == s-offset && !rep { + // Don't retest if we have the same offset. + return match{offset: offset, s: s} + } + + if load32(dict.dict, candidate) != first { + return match{offset: offset, s: s} + } + m := match{offset: offset, s: s, length: 4 + candidate, rep: rep, dict: true} + s += 4 + if !rep { + for s < sLimitDict && m.length < len(dict.dict) { + if len(src)-s < 8 || len(dict.dict)-m.length < 8 { + if src[s] == dict.dict[m.length] { + m.length++ + s++ + continue + } + break + } + if diff := load64(src, s) ^ load64(dict.dict, m.length); diff != 0 { + m.length += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + m.length += 8 + } + } else { + for s < len(src) && m.length < len(dict.dict) { + if len(src)-s < 8 || len(dict.dict)-m.length < 8 { + if src[s] == dict.dict[m.length] { + m.length++ + s++ + continue + } + break + } + if diff := load64(src, s) ^ load64(dict.dict, m.length); diff != 0 { + m.length += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + m.length += 8 + } + } + m.length -= candidate + m.score = score(m) + if m.score <= -m.s { + // Eliminate if no savings, we might find a better one. + m.length = 0 + } + return m + } bestOf := func(a, b match) match { if b.length == 0 { @@ -146,35 +228,82 @@ func encodeBlockBest(dst, src []byte) (d int) { return b } - best = bestOf(matchAt(getCur(candidateL), s, uint32(cv), false), matchAt(getPrev(candidateL), s, uint32(cv), false)) - best = bestOf(best, matchAt(getCur(candidateS), s, uint32(cv), false)) - best = bestOf(best, matchAt(getPrev(candidateS), s, uint32(cv), false)) - + if s > 0 { + best = bestOf(matchAt(getCur(candidateL), s, uint32(cv), false), matchAt(getPrev(candidateL), s, uint32(cv), false)) + best = bestOf(best, matchAt(getCur(candidateS), s, uint32(cv), false)) + best = bestOf(best, matchAt(getPrev(candidateS), s, uint32(cv), false)) + } + if dict != nil { + candidateL := dict.bestTableLong[hashL] + candidateS := dict.bestTableShort[hashS] + best = bestOf(best, matchDict(int(candidateL&0xffff), s, uint32(cv), false)) + best = bestOf(best, matchDict(int(candidateL>>16), s, uint32(cv), false)) + best = bestOf(best, matchDict(int(candidateS&0xffff), s, uint32(cv), false)) + best = bestOf(best, matchDict(int(candidateS>>16), s, uint32(cv), false)) + } { - best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8), true)) + if (dict == nil || repeat <= s) && repeat > 0 { + best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8), true)) + } else if s-repeat < -4 && dict != nil { + candidate := len(dict.dict) - (repeat - s) + best = bestOf(best, matchDict(candidate, s, uint32(cv), true)) + candidate++ + best = bestOf(best, matchDict(candidate, s+1, uint32(cv>>8), true)) + } + if best.length > 0 { + hashS := hash4(cv>>8, sTableBits) // s+1 - nextShort := sTable[hash4(cv>>8, sTableBits)] + nextShort := sTable[hashS] s := s + 1 cv := load64(src, s) - nextLong := lTable[hash8(cv, lTableBits)] + hashL := hash8(cv, lTableBits) + nextLong := lTable[hashL] best = bestOf(best, matchAt(getCur(nextShort), s, uint32(cv), false)) best = bestOf(best, matchAt(getPrev(nextShort), s, uint32(cv), false)) best = bestOf(best, matchAt(getCur(nextLong), s, uint32(cv), false)) best = bestOf(best, matchAt(getPrev(nextLong), s, uint32(cv), false)) - // Repeat at + 2 - best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8), true)) + + // Dict at + 1 + if dict != nil { + candidateL := dict.bestTableLong[hashL] + candidateS := dict.bestTableShort[hashS] + + best = bestOf(best, matchDict(int(candidateL&0xffff), s, uint32(cv), false)) + best = bestOf(best, matchDict(int(candidateS&0xffff), s, uint32(cv), false)) + } // s+2 if true { - nextShort = sTable[hash4(cv>>8, sTableBits)] + hashS := hash4(cv>>8, sTableBits) + + nextShort = sTable[hashS] s++ cv = load64(src, s) - nextLong = lTable[hash8(cv, lTableBits)] + hashL := hash8(cv, lTableBits) + nextLong = lTable[hashL] + + if (dict == nil || repeat <= s) && repeat > 0 { + // Repeat at + 2 + best = bestOf(best, matchAt(s-repeat, s, uint32(cv), true)) + } else if repeat-s > 4 && dict != nil { + candidate := len(dict.dict) - (repeat - s) + best = bestOf(best, matchDict(candidate, s, uint32(cv), true)) + } best = bestOf(best, matchAt(getCur(nextShort), s, uint32(cv), false)) best = bestOf(best, matchAt(getPrev(nextShort), s, uint32(cv), false)) best = bestOf(best, matchAt(getCur(nextLong), s, uint32(cv), false)) best = bestOf(best, matchAt(getPrev(nextLong), s, uint32(cv), false)) + + // Dict at +2 + // Very small gain + if dict != nil { + candidateL := dict.bestTableLong[hashL] + candidateS := dict.bestTableShort[hashS] + + best = bestOf(best, matchDict(int(candidateL&0xffff), s, uint32(cv), false)) + best = bestOf(best, matchDict(int(candidateS&0xffff), s, uint32(cv), false)) + } } // Search for a match at best match end, see if that is better. // Allow some bytes at the beginning to mismatch. @@ -227,7 +356,7 @@ func encodeBlockBest(dst, src []byte) (d int) { // Extend backwards, not needed for repeats... s = best.s - if !best.rep { + if !best.rep && !best.dict { for best.offset > 0 && s > nextEmit && src[best.offset-1] == src[s-1] { best.offset-- best.length++ @@ -244,7 +373,6 @@ func encodeBlockBest(dst, src []byte) (d int) { base := s offset := s - best.offset - s += best.length if offset > 65535 && s-base <= 5 && !best.rep { @@ -256,16 +384,28 @@ func encodeBlockBest(dst, src []byte) (d int) { cv = load64(src, s) continue } + if debug && nextEmit != base { + fmt.Println("EMIT", base-nextEmit, "literals. base-after:", base) + } d += emitLiteral(dst[d:], src[nextEmit:base]) if best.rep { - if nextEmit > 0 { + if nextEmit > 0 || best.dict { + if debug { + fmt.Println("REPEAT, length", best.length, "offset:", offset, "s-after:", s, "dict:", best.dict, "best:", best) + } // same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset. d += emitRepeat(dst[d:], offset, best.length) } else { - // First match, cannot be repeat. + // First match without dict cannot be a repeat. + if debug { + fmt.Println("COPY, length", best.length, "offset:", offset, "s-after:", s, "dict:", best.dict, "best:", best) + } d += emitCopy(dst[d:], offset, best.length) } } else { + if debug { + fmt.Println("COPY, length", best.length, "offset:", offset, "s-after:", s, "dict:", best.dict, "best:", best) + } d += emitCopy(dst[d:], offset, best.length) } repeat = offset @@ -296,6 +436,9 @@ emitRemainder: if d+len(src)-nextEmit > dstLimit { return 0 } + if debug && nextEmit != s { + fmt.Println("emitted ", len(src)-nextEmit, "literals") + } d += emitLiteral(dst[d:], src[nextEmit:]) } return d @@ -642,7 +785,6 @@ func emitRepeatSize(offset, length int) int { left := 0 if length > maxRepeat { left = length - maxRepeat + 4 - length = maxRepeat - 4 } if left > 0 { return 5 + emitRepeatSize(offset, left) diff --git a/s2/encode_better.go b/s2/encode_better.go index 3b66ba42bf..f46adb4117 100644 --- a/s2/encode_better.go +++ b/s2/encode_better.go @@ -6,6 +6,8 @@ package s2 import ( + "bytes" + "fmt" "math/bits" ) @@ -476,3 +478,623 @@ emitRemainder: } return d } + +// encodeBlockBetterDict encodes a non-empty src to a guaranteed-large-enough dst. It +// assumes that the varint-encoded length of the decompressed bytes has already +// been written. +// +// It also assumes that: +// +// len(dst) >= MaxEncodedLen(len(src)) && +// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize +func encodeBlockBetterDict(dst, src []byte, dict *Dict) (d int) { + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + // Initialize the hash tables. + const ( + // Long hash matches. + lTableBits = 17 + maxLTableSize = 1 << lTableBits + + // Short hash matches. + sTableBits = 14 + maxSTableSize = 1 << sTableBits + + maxAhead = 8 // maximum bytes ahead without checking sLimit + + debug = false + ) + + sLimit := len(src) - inputMargin + if sLimit > MaxDictSrcOffset-maxAhead { + sLimit = MaxDictSrcOffset - maxAhead + } + if len(src) < minNonLiteralBlockSize { + return 0 + } + + dict.initBetter() + + var lTable [maxLTableSize]uint32 + var sTable [maxSTableSize]uint32 + + // Bail if we can't compress to at least this. + dstLimit := len(src) - len(src)>>5 - 6 + + // nextEmit is where in src the next emitLiteral should start from. + nextEmit := 0 + + // The encoded form must start with a literal, as there are no previous + // bytes to copy, so we start looking for hash matches at s == 1. + s := 0 + cv := load64(src, s) + + // We initialize repeat to 0, so we never match on first attempt + repeat := len(dict.dict) - dict.repeat + + // While in dict +searchDict: + for { + candidateL := 0 + nextS := 0 + for { + // Next src position to check + nextS = s + (s-nextEmit)>>7 + 1 + if nextS > sLimit { + break searchDict + } + hashL := hash7(cv, lTableBits) + hashS := hash4(cv, sTableBits) + candidateL = int(lTable[hashL]) + candidateS := int(sTable[hashS]) + dictL := int(dict.betterTableLong[hashL]) + dictS := int(dict.betterTableShort[hashS]) + lTable[hashL] = uint32(s) + sTable[hashS] = uint32(s) + + valLong := load64(src, candidateL) + valShort := load64(src, candidateS) + + // If long matches at least 8 bytes, use that. + if s != 0 { + if cv == valLong { + goto emitMatch + } + if cv == valShort { + candidateL = candidateS + goto emitMatch + } + } + + // Check dict repeat. + if repeat >= s+4 { + candidate := len(dict.dict) - repeat + s + if candidate > 0 && uint32(cv) == load32(dict.dict, candidate) { + // Extend back + base := s + for i := candidate; base > nextEmit && i > 0 && dict.dict[i-1] == src[base-1]; { + i-- + base-- + } + d += emitLiteral(dst[d:], src[nextEmit:base]) + if debug && nextEmit != base { + fmt.Println("emitted ", base-nextEmit, "literals") + } + s += 4 + candidate += 4 + for candidate < len(dict.dict)-8 && s <= len(src)-8 { + if diff := load64(src, s) ^ load64(dict.dict, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + d += emitRepeat(dst[d:], repeat, s-base) + if debug { + fmt.Println("emitted dict repeat length", s-base, "offset:", repeat, "s:", s) + } + nextEmit = s + if s >= sLimit { + break searchDict + } + cv = load64(src, s) + // Index in-between + index0 := base + 1 + index1 := s - 2 + + cv = load64(src, s) + for index0 < index1 { + cv0 := load64(src, index0) + cv1 := load64(src, index1) + lTable[hash7(cv0, lTableBits)] = uint32(index0) + sTable[hash4(cv0>>8, sTableBits)] = uint32(index0 + 1) + + lTable[hash7(cv1, lTableBits)] = uint32(index1) + sTable[hash4(cv1>>8, sTableBits)] = uint32(index1 + 1) + index0 += 2 + index1 -= 2 + } + continue + } + } + // Don't try to find match at s==0 + if s == 0 { + cv = load64(src, nextS) + s = nextS + continue + } + + // Long likely matches 7, so take that. + if uint32(cv) == uint32(valLong) { + goto emitMatch + } + + // Long dict... + if uint32(cv) == load32(dict.dict, dictL) { + candidateL = dictL + goto emitDict + } + + // Check our short candidate + if uint32(cv) == uint32(valShort) { + // Try a long candidate at s+1 + hashL = hash7(cv>>8, lTableBits) + candidateL = int(lTable[hashL]) + lTable[hashL] = uint32(s + 1) + if uint32(cv>>8) == load32(src, candidateL) { + s++ + goto emitMatch + } + // Use our short candidate. + candidateL = candidateS + goto emitMatch + } + if uint32(cv) == load32(dict.dict, dictS) { + // Try a long candidate at s+1 + hashL = hash7(cv>>8, lTableBits) + candidateL = int(lTable[hashL]) + lTable[hashL] = uint32(s + 1) + if uint32(cv>>8) == load32(src, candidateL) { + s++ + goto emitMatch + } + candidateL = dictS + goto emitDict + } + cv = load64(src, nextS) + s = nextS + } + emitDict: + { + if debug { + if load32(dict.dict, candidateL) != load32(src, s) { + panic("dict emit mismatch") + } + } + // Extend backwards. + // The top bytes will be rechecked to get the full match. + for candidateL > 0 && s > nextEmit && dict.dict[candidateL-1] == src[s-1] { + candidateL-- + s-- + } + + // Bail if we exceed the maximum size. + if d+(s-nextEmit) > dstLimit { + return 0 + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + d += emitLiteral(dst[d:], src[nextEmit:s]) + if debug && nextEmit != s { + fmt.Println("emitted ", s-nextEmit, "literals") + } + { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + base := s + offset := s + (len(dict.dict)) - candidateL + + // Extend the 4-byte match as long as possible. + s += 4 + candidateL += 4 + for s <= len(src)-8 && len(dict.dict)-candidateL >= 8 { + if diff := load64(src, s) ^ load64(dict.dict, candidateL); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidateL += 8 + } + + if repeat == offset { + if debug { + fmt.Println("emitted dict repeat, length", s-base, "offset:", offset, "s:", s, "dict offset:", candidateL) + } + d += emitRepeat(dst[d:], offset, s-base) + } else { + if debug { + fmt.Println("emitted dict copy, length", s-base, "offset:", offset, "s:", s, "dict offset:", candidateL) + } + // Matches longer than 64 are split. + if s <= sLimit || s-base < 8 { + d += emitCopy(dst[d:], offset, s-base) + } else { + // Split to ensure we don't start a copy within next block. + d += emitCopy(dst[d:], offset, 4) + d += emitRepeat(dst[d:], offset, s-base-4) + } + repeat = offset + } + if false { + // Validate match. + if s <= candidateL { + panic("s <= candidate") + } + a := src[base:s] + b := dict.dict[base-repeat : base-repeat+(s-base)] + if !bytes.Equal(a, b) { + panic("mismatch") + } + } + + nextEmit = s + if s >= sLimit { + break searchDict + } + + if d > dstLimit { + // Do we have space for more, if not bail. + return 0 + } + + // Index short & long + index0 := base + 1 + index1 := s - 2 + + cv0 := load64(src, index0) + cv1 := load64(src, index1) + lTable[hash7(cv0, lTableBits)] = uint32(index0) + sTable[hash4(cv0>>8, sTableBits)] = uint32(index0 + 1) + + lTable[hash7(cv1, lTableBits)] = uint32(index1) + sTable[hash4(cv1>>8, sTableBits)] = uint32(index1 + 1) + index0 += 1 + index1 -= 1 + cv = load64(src, s) + + // index every second long in between. + for index0 < index1 { + lTable[hash7(load64(src, index0), lTableBits)] = uint32(index0) + lTable[hash7(load64(src, index1), lTableBits)] = uint32(index1) + index0 += 2 + index1 -= 2 + } + } + continue + } + emitMatch: + + // Extend backwards + for candidateL > 0 && s > nextEmit && src[candidateL-1] == src[s-1] { + candidateL-- + s-- + } + + // Bail if we exceed the maximum size. + if d+(s-nextEmit) > dstLimit { + return 0 + } + + base := s + offset := base - candidateL + + // Extend the 4-byte match as long as possible. + s += 4 + candidateL += 4 + for s < len(src) { + if len(src)-s < 8 { + if src[s] == src[candidateL] { + s++ + candidateL++ + continue + } + break + } + if diff := load64(src, s) ^ load64(src, candidateL); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidateL += 8 + } + + if offset > 65535 && s-base <= 5 && repeat != offset { + // Bail if the match is equal or worse to the encoding. + s = nextS + 1 + if s >= sLimit { + goto emitRemainder + } + cv = load64(src, s) + continue + } + + d += emitLiteral(dst[d:], src[nextEmit:base]) + if debug && nextEmit != s { + fmt.Println("emitted ", s-nextEmit, "literals") + } + if repeat == offset { + if debug { + fmt.Println("emitted match repeat, length", s-base, "offset:", offset, "s:", s) + } + d += emitRepeat(dst[d:], offset, s-base) + } else { + if debug { + fmt.Println("emitted match copy, length", s-base, "offset:", offset, "s:", s) + } + d += emitCopy(dst[d:], offset, s-base) + repeat = offset + } + + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + + if d > dstLimit { + // Do we have space for more, if not bail. + return 0 + } + + // Index short & long + index0 := base + 1 + index1 := s - 2 + + cv0 := load64(src, index0) + cv1 := load64(src, index1) + lTable[hash7(cv0, lTableBits)] = uint32(index0) + sTable[hash4(cv0>>8, sTableBits)] = uint32(index0 + 1) + + lTable[hash7(cv1, lTableBits)] = uint32(index1) + sTable[hash4(cv1>>8, sTableBits)] = uint32(index1 + 1) + index0 += 1 + index1 -= 1 + cv = load64(src, s) + + // index every second long in between. + for index0 < index1 { + lTable[hash7(load64(src, index0), lTableBits)] = uint32(index0) + lTable[hash7(load64(src, index1), lTableBits)] = uint32(index1) + index0 += 2 + index1 -= 2 + } + } + + // Search without dict: + if repeat > s { + repeat = 0 + } + + // No more dict + sLimit = len(src) - inputMargin + if s >= sLimit { + goto emitRemainder + } + cv = load64(src, s) + if debug { + fmt.Println("now", s, "->", sLimit, "out:", d, "left:", len(src)-s, "nextemit:", nextEmit, "dstLimit:", dstLimit, "s:", s) + } + for { + candidateL := 0 + nextS := 0 + for { + // Next src position to check + nextS = s + (s-nextEmit)>>7 + 1 + if nextS > sLimit { + goto emitRemainder + } + hashL := hash7(cv, lTableBits) + hashS := hash4(cv, sTableBits) + candidateL = int(lTable[hashL]) + candidateS := int(sTable[hashS]) + lTable[hashL] = uint32(s) + sTable[hashS] = uint32(s) + + valLong := load64(src, candidateL) + valShort := load64(src, candidateS) + + // If long matches at least 8 bytes, use that. + if cv == valLong { + break + } + if cv == valShort { + candidateL = candidateS + break + } + + // Check repeat at offset checkRep. + const checkRep = 1 + // Minimum length of a repeat. Tested with various values. + // While 4-5 offers improvements in some, 6 reduces + // regressions significantly. + const wantRepeatBytes = 6 + const repeatMask = ((1 << (wantRepeatBytes * 8)) - 1) << (8 * checkRep) + if false && repeat > 0 && cv&repeatMask == load64(src, s-repeat)&repeatMask { + base := s + checkRep + // Extend back + for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; { + i-- + base-- + } + d += emitLiteral(dst[d:], src[nextEmit:base]) + + // Extend forward + candidate := s - repeat + wantRepeatBytes + checkRep + s += wantRepeatBytes + checkRep + for s < len(src) { + if len(src)-s < 8 { + if src[s] == src[candidate] { + s++ + candidate++ + continue + } + break + } + if diff := load64(src, s) ^ load64(src, candidate); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidate += 8 + } + // same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset. + d += emitRepeat(dst[d:], repeat, s-base) + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + // Index in-between + index0 := base + 1 + index1 := s - 2 + + cv = load64(src, s) + for index0 < index1 { + cv0 := load64(src, index0) + cv1 := load64(src, index1) + lTable[hash7(cv0, lTableBits)] = uint32(index0) + sTable[hash4(cv0>>8, sTableBits)] = uint32(index0 + 1) + + lTable[hash7(cv1, lTableBits)] = uint32(index1) + sTable[hash4(cv1>>8, sTableBits)] = uint32(index1 + 1) + index0 += 2 + index1 -= 2 + } + + cv = load64(src, s) + continue + } + + // Long likely matches 7, so take that. + if uint32(cv) == uint32(valLong) { + break + } + + // Check our short candidate + if uint32(cv) == uint32(valShort) { + // Try a long candidate at s+1 + hashL = hash7(cv>>8, lTableBits) + candidateL = int(lTable[hashL]) + lTable[hashL] = uint32(s + 1) + if uint32(cv>>8) == load32(src, candidateL) { + s++ + break + } + // Use our short candidate. + candidateL = candidateS + break + } + + cv = load64(src, nextS) + s = nextS + } + + // Extend backwards + for candidateL > 0 && s > nextEmit && src[candidateL-1] == src[s-1] { + candidateL-- + s-- + } + + // Bail if we exceed the maximum size. + if d+(s-nextEmit) > dstLimit { + return 0 + } + + base := s + offset := base - candidateL + + // Extend the 4-byte match as long as possible. + s += 4 + candidateL += 4 + for s < len(src) { + if len(src)-s < 8 { + if src[s] == src[candidateL] { + s++ + candidateL++ + continue + } + break + } + if diff := load64(src, s) ^ load64(src, candidateL); diff != 0 { + s += bits.TrailingZeros64(diff) >> 3 + break + } + s += 8 + candidateL += 8 + } + + if offset > 65535 && s-base <= 5 && repeat != offset { + // Bail if the match is equal or worse to the encoding. + s = nextS + 1 + if s >= sLimit { + goto emitRemainder + } + cv = load64(src, s) + continue + } + + d += emitLiteral(dst[d:], src[nextEmit:base]) + if repeat == offset { + d += emitRepeat(dst[d:], offset, s-base) + } else { + d += emitCopy(dst[d:], offset, s-base) + repeat = offset + } + + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + + if d > dstLimit { + // Do we have space for more, if not bail. + return 0 + } + + // Index short & long + index0 := base + 1 + index1 := s - 2 + + cv0 := load64(src, index0) + cv1 := load64(src, index1) + lTable[hash7(cv0, lTableBits)] = uint32(index0) + sTable[hash4(cv0>>8, sTableBits)] = uint32(index0 + 1) + + lTable[hash7(cv1, lTableBits)] = uint32(index1) + sTable[hash4(cv1>>8, sTableBits)] = uint32(index1 + 1) + index0 += 1 + index1 -= 1 + cv = load64(src, s) + + // index every second long in between. + for index0 < index1 { + lTable[hash7(load64(src, index0), lTableBits)] = uint32(index0) + lTable[hash7(load64(src, index1), lTableBits)] = uint32(index1) + index0 += 2 + index1 -= 2 + } + } + +emitRemainder: + if nextEmit < len(src) { + // Bail if we exceed the maximum size. + if d+len(src)-nextEmit > dstLimit { + return 0 + } + d += emitLiteral(dst[d:], src[nextEmit:]) + } + return d +} diff --git a/s2/examples_test.go b/s2/examples_test.go new file mode 100644 index 0000000000..3282ee8b8c --- /dev/null +++ b/s2/examples_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2023+ Klaus Post. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package s2_test + +import ( + "bytes" + "fmt" + "os" + + "github.com/klauspost/compress/s2" + "github.com/klauspost/compress/zstd" +) + +func ExampleMakeDict() { + // Read a sample + sample, err := os.ReadFile("../testdata/gettysburg.txt") + if err != nil { + panic(err) + } + fmt.Println("Input size:", len(sample)) + + // Create a dictionary. + dict := s2.MakeDict(sample, nil) + fmt.Println("Dict size:", len(dict.Bytes())) + + encoded := dict.Encode(nil, sample) + if len(encoded) < 20 { + fmt.Println("Encoded size was less than 20 bytes!") + } + + // To decode: + decoded, err := dict.Decode(nil, encoded) + if err != nil { + panic(err) + } + if bytes.Equal(decoded, sample) { + fmt.Println("They match!") + } + // OUTPUT: + // Input size: 1548 + // Dict size: 1549 + // Encoded size was less than 20 bytes! + // They match! +} + +func ExampleMakeDict_zstd() { + // Read dictionary generated by zStandard using the command line + // λ zstd -r --train-fastcover -o zstd.dict --maxdict=2048 gosrc\* + // With gosrc containing all the standard library source files. + zdict := []byte("7\xa40콶\xc1\x1bB\x10\x982\xc4\xe9\xc0\xc0\xc0\xc0\xc0\xc0\xc0\xc0\xc0\xc0\xc0\xc0@\xf5<\xda#\"{\xb7\xb6\xdd\xdd\xda\x17\x1b\t\x9b\xbd\x13n{U\xc1k\x11\xc3\x1b\x8b\xfbX\xee\xfe\xcb1\xcai\f\xf6meE\x97\x19\x83\\f\x14\x00\\\tS\x01\x00\x18 \x18\x8f\aT\x1a\xf5\x00\x00\x04\x80O\xd3MIJH\x03q\x98$I\n\xa3\x10B\xc6\x18B\b\x01\x00\x00D\x00\x04\x04\x00\xc0\x00\x00\x004\xcdieĩ@Β \xc7\x14B\n͌\b\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\b\x00\x00\x00kage types2\n\nimport (\n\t\"cmd/compile/internal/syntax\"\n\t\"strings\"\n\t\"unicode\"\n)\n\n// funcInst type-checks a func\")\n\tif err != nil {\n\t\tt.Fatalf(\"Prepare: %v\", err)\n\t}\n\tdefer stmt.Close()\n\n\tconst n = 10\n\tch := make(chan error, n)\n\tfor i := 0; i < n; i++ {\n\t\tgo func() {\n\t\t\tvar age int\n\t\t\terr := stmt.QueryRowool { return c != nil && c.fd != nil }\n\n// Implementation of the Conn interface.\n\n// Read implements the Conn Read method.\nfunc (c *conn) Read(b []byte) (int, error) {\n\tif !c.ok() {\n\t\treturn 0, t\n\t\t} else {\n\t\t\treturn nil, &FormatError{0, \"invalid magic number\", nil}\n\t\t}\n\t}\n\toffset := int64(4)\n\n\t// Read the number of FatArchHeaders that come after the fat_header.\n\tvar narch uint32\n\terr log.Fatal(err)\n\t\t}\n\t\tf := strings.Fields(line)\n\t\tif len(f) == 0 {\n\t\t\tcontinue\n\t\t}\n\t\tswitch f[0] {\n\t\tdefault:\n\t\t\tfmt.Fprintf(os.Stderr, \"?unknown command\\n\")\n\t\t\tcontinue\n\t\tcase \"tags\":\n\t\t\tprefix 00\\x00\\x00\", true},\n\t}\n\n\tfor _, v := range vectors {\n\t\tvar f formatter\n\t\tgot := make([]byte, len(v.want))\n\t\tf.formatNumeric(got, v.in)\n\t\tok := (f.err == nil)\n\t\tif ok != v.ok {\n\t\t\tif v.ok {\n\t\t\t\ttturn true\n\t}\n\treturn false\n}\nfunc rewriteValueARM_OpARMBICconst(v *Value) bool {\n\tv_0 := v.Args[0]\n\t// match: (BICconst [0] x)\n\t// result: x\n\tfor {\n\t\tif auxIntToInt32(v.AuxInt) != 0 {\n\t\t\tbreak\n\tnt) {\n\t\t\t\t\tt.Errorf(\"%5g %s %5g = %5s; want %5s\", x, op, y, got, want)\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t}\n}\n\nfunc TestFloatArithmeticOverflow(t *testing.T) {\n\tfor _, test := range []struct {\n\t\tprec uint\n\t\t)\n\t\t\t}\n\t\t\treturn\n\t\t}\n\t}\n}\n// Copyright 2017 The Go Authors. All rights reserved.\n// Use of this source code is governed by a BSD-style\n// license that can be found in the LICENSE file.\n\npackage ), uintptr(unsafe.Pointer(_p1)), 0)\n\tif e1 != 0 {\n\t\terr = errnoErr(e1)\n\t}\n\treturn\n}\n\n// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT\n\nfunc Sync() (err error) {\n\t_, _, e1 := SyscDLINK = 0x10\n\tMOVEFILE_FAIL_IF_NOT") + + // Decode the zstandard dictionary. + insp, err := zstd.InspectDictionary(zdict) + if err != nil { + panic(err) + } + + // We are only interested in the contents. + fmt.Println("Dictionary content length:", len(insp.Content())) + + // Create a dictionary. + // Assume that files start with "// Copyright (c) 2023". + // Search for the longest match for that. + // This may save a few bytes. + dict := s2.MakeDict(insp.Content(), []byte("// Copyright (c) 2023")) + + // b := d.Bytes() will provide a dictionary that can be saved + // and reloaded with s2.NewDict(b). + + fmt.Println("Dict size:", len(dict.Bytes())) + + // Read a sample. Use this file. + sample, err := os.ReadFile("examples_test.go") + if err != nil { + panic(err) + } + + encodedWithDict := dict.Encode(nil, sample) + encodedNoDict := s2.Encode(nil, sample) + + // Print a less accurate output that is less likely to change. + // Since we include the (encoded) dictionary itself that will create better than expected compression. + if len(encodedWithDict) < len(encodedNoDict)-1000 { + fmt.Println("Saved more than 1000 bytes") + } + + // To decode the content: + decoded, err := dict.Decode(nil, encodedWithDict) + if err != nil { + panic(err) + } + if bytes.Equal(decoded, sample) { + fmt.Println("They match!") + } + // OUTPUT: + // Dictionary content length: 1894 + // Dict size: 1896 + // Saved more than 1000 bytes + // They match! +} diff --git a/s2/testdata/fuzz/block-corpus-enc.zip b/s2/testdata/fuzz/block-corpus-enc.zip index 1e44ebcefb..7f0832c835 100644 Binary files a/s2/testdata/fuzz/block-corpus-enc.zip and b/s2/testdata/fuzz/block-corpus-enc.zip differ diff --git a/s2/testdata/s2-dict.bin.gz b/s2/testdata/s2-dict.bin.gz new file mode 100644 index 0000000000..f68dac7e48 Binary files /dev/null and b/s2/testdata/s2-dict.bin.gz differ diff --git a/zstd/dict.go b/zstd/dict.go index 66a95c18ef..ca0951452e 100644 --- a/zstd/dict.go +++ b/zstd/dict.go @@ -32,14 +32,38 @@ func (d *dict) ID() uint32 { return d.id } -// DictContentSize returns the dictionary content size or 0 if d is nil. -func (d *dict) DictContentSize() int { +// ContentSize returns the dictionary content size or 0 if d is nil. +func (d *dict) ContentSize() int { if d == nil { return 0 } return len(d.content) } +// Content returns the dictionary content. +func (d *dict) Content() []byte { + if d == nil { + return nil + } + return d.content +} + +// Offsets returns the initial offsets. +func (d *dict) Offsets() [3]int { + if d == nil { + return [3]int{} + } + return d.offsets +} + +// LitEncoder returns the literal encoder. +func (d *dict) LitEncoder() *huff0.Scratch { + if d == nil { + return nil + } + return d.litEnc +} + // Load a dictionary as described in // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format func loadDict(b []byte) (*dict, error) { @@ -64,7 +88,7 @@ func loadDict(b []byte) (*dict, error) { var err error d.litEnc, b, err = huff0.ReadTable(b[8:], nil) if err != nil { - return nil, err + return nil, fmt.Errorf("loading literal table: %w", err) } d.litEnc.Reuse = huff0.ReusePolicyMust @@ -122,3 +146,16 @@ func loadDict(b []byte) (*dict, error) { return &d, nil } + +// InspectDictionary loads a zstd dictionary and provides functions to inspect the content. +func InspectDictionary(b []byte) (interface { + ID() uint32 + ContentSize() int + Content() []byte + Offsets() [3]int + LitEncoder() *huff0.Scratch +}, error) { + initPredefined() + d, err := loadDict(b) + return d, err +} diff --git a/zstd/enc_base.go b/zstd/enc_base.go index bfb2e146c3..e008b99298 100644 --- a/zstd/enc_base.go +++ b/zstd/enc_base.go @@ -149,7 +149,7 @@ func (e *fastBase) resetBase(d *dict, singleBlock bool) { if singleBlock { e.lowMem = true } - e.ensureHist(d.DictContentSize() + maxCompressedBlockSize) + e.ensureHist(d.ContentSize() + maxCompressedBlockSize) e.lowMem = low }