Skip to content

Commit

Permalink
Support cgozstd dictionaries
Browse files Browse the repository at this point in the history
Updates #23
  • Loading branch information
nigeltao committed Oct 21, 2019
1 parent 7b0391f commit 2fd8649
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 11 deletions.
124 changes: 115 additions & 9 deletions lib/cgozstd/cgozstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,66 @@
// Package cgozstd wraps the C "zstd" library.
package cgozstd

// TODO: dictionaries. See https://github.com/facebook/zstd/issues/1776

/*
#cgo pkg-config: libzstd
#include "zstd.h"
#include "zstd_errors.h"
#include <stdint.h>
// --------
// cgozstd_compress_start and cgozstd_decompress_start (re)initialize a zstd
// streaming context, optionally with a dictionary given as dict_ptr/dict_len.
//
// Both return 0 on success, a positive ZSTD_ErrorCode if zstd reports an
// error, or -1 if a non-empty dictionary was supplied but this build was
// compiled against a zstd older than 1.3.
//
// The minor version is only compared when the major version is exactly 1, so
// that a future 2.0 release is not mistaken for something older than 1.3.
#if (ZSTD_VERSION_MAJOR < 1) || \
    ((ZSTD_VERSION_MAJOR == 1) && (ZSTD_VERSION_MINOR < 3))
// Pre-1.3 zstd: dictionaries are rejected, plain streams still work.
int32_t cgozstd_compress_start(ZSTD_CCtx* z,
                               uint8_t* dict_ptr,
                               uint32_t dict_len,
                               int compression_level) {
  if (dict_len > 0) {
    return -1;
  }
  return ZSTD_getErrorCode(ZSTD_initCStream(z, compression_level));
}
int32_t cgozstd_decompress_start(ZSTD_DCtx* z,
                                 uint8_t* dict_ptr,
                                 uint32_t dict_len) {
  if (dict_len > 0) {
    return -1;
  }
  return ZSTD_getErrorCode(ZSTD_initDStream(z));
}
#else
// TODO: don't use the unsupported ZSTD_initFoo_usingDict API.
//
// The prototypes are repeated here so that the calls below have a declaration
// in scope.
ZSTDLIB_API size_t ZSTD_initCStream_usingDict(
    ZSTD_CCtx* z,
    const void* dict_ptr,
    size_t dict_len,
    int compression_level);
ZSTDLIB_API size_t ZSTD_initDStream_usingDict(
    ZSTD_DCtx* z,
    const void* dict_ptr,
    size_t dict_len);
int32_t cgozstd_compress_start(ZSTD_CCtx* z,
                               uint8_t* dict_ptr,
                               uint32_t dict_len,
                               int compression_level) {
  return ZSTD_getErrorCode(ZSTD_initCStream_usingDict(
      z, dict_ptr, dict_len, compression_level));
}
int32_t cgozstd_decompress_start(ZSTD_DCtx* z,
                                 uint8_t* dict_ptr,
                                 uint32_t dict_len) {
  return ZSTD_getErrorCode(ZSTD_initDStream_usingDict(
      z, dict_ptr, dict_len));
}
#endif
// --------
typedef struct {
uint32_t ndst;
uint32_t nsrc;
Expand Down Expand Up @@ -116,11 +167,12 @@ const cgoEnabled = true
const maxLen = 1 << 30

// Errors returned by this package. Each identifier is declared exactly once;
// declaring any of them twice in this block is a compile-time redeclaration
// error.
var (
	errMissingResetCall = errors.New("cgozstd: missing Reset call")
	// errNilIOReader is returned by Reader.Reset when given a nil io.Reader.
	errNilIOReader = errors.New("cgozstd: nil io.Reader")
	// errNilIOWriter is returned by Writer.Reset when given a nil io.Writer.
	errNilIOWriter = errors.New("cgozstd: nil io.Writer")
	errNilReceiver = errors.New("cgozstd: nil receiver")
	// errOutOfMemory is returned when the C library fails to allocate a
	// stream object.
	errOutOfMemory = errors.New("cgozstd: out of memory")
	// errZstdVersionTooSmall is returned when the C start helpers report a
	// negative code, which they do when a dictionary is supplied but the
	// package was built against a zstd older than 1.3.
	errZstdVersionTooSmall = errors.New("cgozstd: zstd version too small (1.3 minimum)")
)

type errCode int32
Expand Down Expand Up @@ -165,6 +217,8 @@ type Reader struct {
i, j uint32
r io.Reader

dictionary []byte

readErr error
zstdErr error

Expand All @@ -185,7 +239,12 @@ func (r *Reader) Reset(reader io.Reader, dictionary []byte) error {
if reader == nil {
return errNilIOReader
}
if len(dictionary) > maxLen {
dictionary = dictionary[len(dictionary)-maxLen:]
}

r.r = reader
r.dictionary = dictionary
return nil
}

Expand Down Expand Up @@ -225,13 +284,32 @@ func (r *Reader) Read(p []byte) (int, error) {
if r.z == nil {
if (r.recycler != nil) && !r.recycler.closed && (r.recycler.z != nil) {
r.z, r.recycler.z = r.recycler.z, nil
C.ZSTD_initDStream(r.z)
} else {
r.z = C.ZSTD_createDStream()
if r.z == nil {
return 0, errOutOfMemory
}
}

e := errCode(0)
if len(r.dictionary) == 0 {
e = errCode(C.cgozstd_decompress_start(r.z,
(*C.uint8_t)(nil),
(C.uint32_t)(0),
))
} else {
e = errCode(C.cgozstd_decompress_start(r.z,
(*C.uint8_t)(unsafe.Pointer(&r.dictionary[0])),
(C.uint32_t)(len(r.dictionary)),
))
}
if e < 0 {
r.zstdErr = errZstdVersionTooSmall
return 0, r.zstdErr
} else if e != 0 {
r.zstdErr = e
return 0, r.zstdErr
}
}

if len(p) > maxLen {
Expand Down Expand Up @@ -320,6 +398,8 @@ type Writer struct {
w io.Writer
level compression.Level

dictionary []byte

writeErr error

recycler *WriterRecycler
Expand All @@ -341,8 +421,13 @@ func (w *Writer) Reset(writer io.Writer, dictionary []byte, level compression.Le
if writer == nil {
return errNilIOWriter
}
if len(dictionary) > maxLen {
dictionary = dictionary[len(dictionary)-maxLen:]
}

w.w = writer
w.level = level
w.dictionary = dictionary
return nil
}

Expand Down Expand Up @@ -436,7 +521,28 @@ func (w *Writer) write(p []byte, final bool) error {
return errOutOfMemory
}
}
C.ZSTD_initCStream(w.z, C.int(w.zstdCompressionLevel()))

e := errCode(0)
if len(w.dictionary) == 0 {
e = errCode(C.cgozstd_compress_start(w.z,
(*C.uint8_t)(nil),
(C.uint32_t)(0),
C.int(w.zstdCompressionLevel()),
))
} else {
e = errCode(C.cgozstd_compress_start(w.z,
(*C.uint8_t)(unsafe.Pointer(&w.dictionary[0])),
(C.uint32_t)(len(w.dictionary)),
C.int(w.zstdCompressionLevel()),
))
}
if e < 0 {
w.writeErr = errZstdVersionTooSmall
return w.writeErr
} else if e != 0 {
w.writeErr = e
return w.writeErr
}
}

for (len(p) > 0) || final {
Expand Down
71 changes: 69 additions & 2 deletions lib/cgozstd/cgozstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ func TestRoundTrip(tt *testing.T) {

// Compress.
{
w.Reset(buf, nil, 0)
if err := w.Reset(buf, nil, 0); err != nil {
w.Close()
tt.Fatalf("i=%d: Reset: %v", i, err)
}
if _, err := w.Write([]byte(uncompressedMore)); err != nil {
w.Close()
tt.Fatalf("i=%d: Write: %v", i, err)
Expand All @@ -80,7 +83,10 @@ func TestRoundTrip(tt *testing.T) {

// Uncompress.
{
r.Reset(strings.NewReader(compressed), nil)
if err := r.Reset(strings.NewReader(compressed), nil); err != nil {
r.Close()
tt.Fatalf("i=%d: Reset: %v", i, err)
}
gotBytes, err := ioutil.ReadAll(r)
if err != nil {
r.Close()
Expand All @@ -96,3 +102,64 @@ func TestRoundTrip(tt *testing.T) {
}
}
}

// TestDictionary compresses and then decompresses a short payload twice:
// first without a dictionary and then with one whose contents are a prefix of
// the payload. It checks that the round trip is lossless in both cases, that
// the dictionary-assisted output is under 30 bytes, and that the
// dictionary-free output is at least 50 bytes.
func TestDictionary(tt *testing.T) {
	if !cgoEnabled {
		tt.Skip("cgo is not enabled")
	}

	const (
		abc          = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
		uncompressed = abc + "123"
	)

	testCases := []struct {
		name       string
		withDict   bool
		dictionary []byte
	}{
		{name: "sans dictionary", withDict: false, dictionary: nil},
		{name: "with dictionary", withDict: true, dictionary: []byte(abc)},
	}

	for _, tc := range testCases {
		buf := &bytes.Buffer{}

		// Compress. The Writer is closed explicitly before every fatal exit.
		w := &Writer{}
		if err := w.Reset(buf, tc.dictionary, 0); err != nil {
			w.Close()
			tt.Fatalf("%s: Reset: %v", tc.name, err)
		}
		if _, err := w.Write([]byte(uncompressed)); err != nil {
			w.Close()
			tt.Fatalf("%s: Write: %v", tc.name, err)
		}
		if err := w.Close(); err != nil {
			tt.Fatalf("%s: Close: %v", tc.name, err)
		}

		// Check that the dictionary had the expected effect on output size.
		compressed := buf.String()
		n := buf.Len()
		switch {
		case tc.withDict && n >= 30:
			tt.Fatalf("%s: compressed length: got %d, want < 30", tc.name, n)
		case !tc.withDict && n < 50:
			tt.Fatalf("%s: compressed length: got %d, want >= 50", tc.name, n)
		}

		// Uncompress with the same dictionary and compare to the input.
		r := &Reader{}
		if err := r.Reset(strings.NewReader(compressed), tc.dictionary); err != nil {
			r.Close()
			tt.Fatalf("%s: Reset: %v", tc.name, err)
		}
		gotBytes, err := ioutil.ReadAll(r)
		if err != nil {
			r.Close()
			tt.Fatalf("%s: ReadAll: %v", tc.name, err)
		}
		if got, want := string(gotBytes), uncompressed; got != want {
			r.Close()
			tt.Fatalf("%s:\ngot %q\nwant %q", tc.name, got, want)
		}
		if err := r.Close(); err != nil {
			tt.Fatalf("%s: Close: %v", tc.name, err)
		}
	}
}

0 comments on commit 2fd8649

Please sign in to comment.