Skip to content

Commit 2fd8649

Browse files
committed
Support cgozstd dictionaries
Updates #23
1 parent 7b0391f commit 2fd8649

File tree

2 files changed

+184
-11
lines changed

2 files changed

+184
-11
lines changed

lib/cgozstd/cgozstd.go

+115-9
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,66 @@
1717
// Package cgozstd wraps the C "zstd" library.
1818
package cgozstd
1919

20-
// TODO: dictionaries. See https://github.com/facebook/zstd/issues/1776
21-
2220
/*
2321
#cgo pkg-config: libzstd
2422
#include "zstd.h"
2523
#include "zstd_errors.h"
2624
2725
#include <stdint.h>
2826
27+
// --------
28+
#if (ZSTD_VERSION_MAJOR < 1) || (ZSTD_VERSION_MINOR < 3)
29+
30+
int32_t cgozstd_compress_start(ZSTD_CCtx* z,
31+
uint8_t* dict_ptr,
32+
uint32_t dict_len,
33+
int compression_level) {
34+
if (dict_len > 0) {
35+
return -1;
36+
}
37+
return ZSTD_getErrorCode(ZSTD_initCStream(z, compression_level));
38+
}
39+
40+
int32_t cgozstd_decompress_start(ZSTD_DCtx* z,
41+
uint8_t* dict_ptr,
42+
uint32_t dict_len) {
43+
if (dict_len > 0) {
44+
return -1;
45+
}
46+
return ZSTD_getErrorCode(ZSTD_initDStream(z));
47+
}
48+
49+
#else
50+
51+
// TODO: don't use the unsupported ZSTD_initFoo_usingDict API.
52+
ZSTDLIB_API size_t ZSTD_initCStream_usingDict(
53+
ZSTD_CCtx* z,
54+
const void* dict_ptr,
55+
size_t dict_len,
56+
int compression_level);
57+
ZSTDLIB_API size_t ZSTD_initDStream_usingDict(
58+
ZSTD_DCtx* z,
59+
const void* dict_ptr,
60+
size_t dict_len);
61+
62+
int32_t cgozstd_compress_start(ZSTD_CCtx* z,
63+
uint8_t* dict_ptr,
64+
uint32_t dict_len,
65+
int compression_level) {
66+
return ZSTD_getErrorCode(ZSTD_initCStream_usingDict(
67+
z, dict_ptr, dict_len, compression_level));
68+
}
69+
70+
int32_t cgozstd_decompress_start(ZSTD_DCtx* z,
71+
uint8_t* dict_ptr,
72+
uint32_t dict_len) {
73+
return ZSTD_getErrorCode(ZSTD_initDStream_usingDict(
74+
z, dict_ptr, dict_len));
75+
}
76+
77+
#endif
78+
// --------
79+
2980
typedef struct {
3081
uint32_t ndst;
3182
uint32_t nsrc;
@@ -116,11 +167,12 @@ const cgoEnabled = true
116167
const maxLen = 1 << 30
117168

118169
var (
119-
errMissingResetCall = errors.New("cgozstd: missing Reset call")
120-
errNilIOReader = errors.New("cgozstd: nil io.Reader")
121-
errNilIOWriter = errors.New("cgozstd: nil io.Writer")
122-
errNilReceiver = errors.New("cgozstd: nil receiver")
123-
errOutOfMemory = errors.New("cgozstd: out of memory")
170+
errMissingResetCall = errors.New("cgozstd: missing Reset call")
171+
errNilIOReader = errors.New("cgozstd: nil io.Reader")
172+
errNilIOWriter = errors.New("cgozstd: nil io.Writer")
173+
errNilReceiver = errors.New("cgozstd: nil receiver")
174+
errOutOfMemory = errors.New("cgozstd: out of memory")
175+
errZstdVersionTooSmall = errors.New("cgozstd: zstd version too small (1.3 minimum)")
124176
)
125177

126178
type errCode int32
@@ -165,6 +217,8 @@ type Reader struct {
165217
i, j uint32
166218
r io.Reader
167219

220+
dictionary []byte
221+
168222
readErr error
169223
zstdErr error
170224

@@ -185,7 +239,12 @@ func (r *Reader) Reset(reader io.Reader, dictionary []byte) error {
185239
if reader == nil {
186240
return errNilIOReader
187241
}
242+
if len(dictionary) > maxLen {
243+
dictionary = dictionary[len(dictionary)-maxLen:]
244+
}
245+
188246
r.r = reader
247+
r.dictionary = dictionary
189248
return nil
190249
}
191250

@@ -225,13 +284,32 @@ func (r *Reader) Read(p []byte) (int, error) {
225284
if r.z == nil {
226285
if (r.recycler != nil) && !r.recycler.closed && (r.recycler.z != nil) {
227286
r.z, r.recycler.z = r.recycler.z, nil
228-
C.ZSTD_initDStream(r.z)
229287
} else {
230288
r.z = C.ZSTD_createDStream()
231289
if r.z == nil {
232290
return 0, errOutOfMemory
233291
}
234292
}
293+
294+
e := errCode(0)
295+
if len(r.dictionary) == 0 {
296+
e = errCode(C.cgozstd_decompress_start(r.z,
297+
(*C.uint8_t)(nil),
298+
(C.uint32_t)(0),
299+
))
300+
} else {
301+
e = errCode(C.cgozstd_decompress_start(r.z,
302+
(*C.uint8_t)(unsafe.Pointer(&r.dictionary[0])),
303+
(C.uint32_t)(len(r.dictionary)),
304+
))
305+
}
306+
if e < 0 {
307+
r.zstdErr = errZstdVersionTooSmall
308+
return 0, r.zstdErr
309+
} else if e != 0 {
310+
r.zstdErr = e
311+
return 0, r.zstdErr
312+
}
235313
}
236314

237315
if len(p) > maxLen {
@@ -320,6 +398,8 @@ type Writer struct {
320398
w io.Writer
321399
level compression.Level
322400

401+
dictionary []byte
402+
323403
writeErr error
324404

325405
recycler *WriterRecycler
@@ -341,8 +421,13 @@ func (w *Writer) Reset(writer io.Writer, dictionary []byte, level compression.Le
341421
if writer == nil {
342422
return errNilIOWriter
343423
}
424+
if len(dictionary) > maxLen {
425+
dictionary = dictionary[len(dictionary)-maxLen:]
426+
}
427+
344428
w.w = writer
345429
w.level = level
430+
w.dictionary = dictionary
346431
return nil
347432
}
348433

@@ -436,7 +521,28 @@ func (w *Writer) write(p []byte, final bool) error {
436521
return errOutOfMemory
437522
}
438523
}
439-
C.ZSTD_initCStream(w.z, C.int(w.zstdCompressionLevel()))
524+
525+
e := errCode(0)
526+
if len(w.dictionary) == 0 {
527+
e = errCode(C.cgozstd_compress_start(w.z,
528+
(*C.uint8_t)(nil),
529+
(C.uint32_t)(0),
530+
C.int(w.zstdCompressionLevel()),
531+
))
532+
} else {
533+
e = errCode(C.cgozstd_compress_start(w.z,
534+
(*C.uint8_t)(unsafe.Pointer(&w.dictionary[0])),
535+
(C.uint32_t)(len(w.dictionary)),
536+
C.int(w.zstdCompressionLevel()),
537+
))
538+
}
539+
if e < 0 {
540+
w.writeErr = errZstdVersionTooSmall
541+
return w.writeErr
542+
} else if e != 0 {
543+
w.writeErr = e
544+
return w.writeErr
545+
}
440546
}
441547

442548
for (len(p) > 0) || final {

lib/cgozstd/cgozstd_test.go

+69-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ func TestRoundTrip(tt *testing.T) {
6262

6363
// Compress.
6464
{
65-
w.Reset(buf, nil, 0)
65+
if err := w.Reset(buf, nil, 0); err != nil {
66+
w.Close()
67+
tt.Fatalf("i=%d: Reset: %v", i, err)
68+
}
6669
if _, err := w.Write([]byte(uncompressedMore)); err != nil {
6770
w.Close()
6871
tt.Fatalf("i=%d: Write: %v", i, err)
@@ -80,7 +83,10 @@ func TestRoundTrip(tt *testing.T) {
8083

8184
// Uncompress.
8285
{
83-
r.Reset(strings.NewReader(compressed), nil)
86+
if err := r.Reset(strings.NewReader(compressed), nil); err != nil {
87+
r.Close()
88+
tt.Fatalf("i=%d: Reset: %v", i, err)
89+
}
8490
gotBytes, err := ioutil.ReadAll(r)
8591
if err != nil {
8692
r.Close()
@@ -96,3 +102,64 @@ func TestRoundTrip(tt *testing.T) {
96102
}
97103
}
98104
}
105+
106+
func TestDictionary(tt *testing.T) {
107+
if !cgoEnabled {
108+
tt.Skip("cgo is not enabled")
109+
}
110+
111+
const (
112+
abc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
113+
uncompressed = abc + "123"
114+
)
115+
116+
for _, withDict := range []bool{false, true} {
117+
buf := &bytes.Buffer{}
118+
dictionary, name := []byte(nil), "sans dictionary"
119+
if withDict {
120+
dictionary, name = []byte(abc), "with dictionary"
121+
}
122+
123+
w := &Writer{}
124+
if err := w.Reset(buf, dictionary, 0); err != nil {
125+
w.Close()
126+
tt.Fatalf("%s: Reset: %v", name, err)
127+
}
128+
if _, err := w.Write([]byte(uncompressed)); err != nil {
129+
w.Close()
130+
tt.Fatalf("%s: Write: %v", name, err)
131+
}
132+
if err := w.Close(); err != nil {
133+
tt.Fatalf("%s: Close: %v", name, err)
134+
}
135+
136+
compressed := buf.String()
137+
if withDict {
138+
if n := buf.Len(); n >= 30 {
139+
tt.Fatalf("%s: compressed length: got %d, want < 30", name, n)
140+
}
141+
} else {
142+
if n := buf.Len(); n < 50 {
143+
tt.Fatalf("%s: compressed length: got %d, want >= 50", name, n)
144+
}
145+
}
146+
147+
r := &Reader{}
148+
if err := r.Reset(strings.NewReader(compressed), dictionary); err != nil {
149+
r.Close()
150+
tt.Fatalf("%s: Reset: %v", name, err)
151+
}
152+
gotBytes, err := ioutil.ReadAll(r)
153+
if err != nil {
154+
r.Close()
155+
tt.Fatalf("%s: ReadAll: %v", name, err)
156+
}
157+
if got, want := string(gotBytes), uncompressed; got != want {
158+
r.Close()
159+
tt.Fatalf("%s:\ngot %q\nwant %q", name, got, want)
160+
}
161+
if err := r.Close(); err != nil {
162+
tt.Fatalf("%s: Close: %v", name, err)
163+
}
164+
}
165+
}

0 commit comments

Comments
 (0)