17
17
// Package cgozstd wraps the C "zstd" library.
18
18
package cgozstd
19
19
20
- // TODO: dictionaries. See https://github.com/facebook/zstd/issues/1776
21
-
22
20
/*
23
21
#cgo pkg-config: libzstd
24
22
#include "zstd.h"
25
23
#include "zstd_errors.h"
26
24
27
25
#include <stdint.h>
28
26
27
+ // --------
28
+ #if (ZSTD_VERSION_MAJOR < 1) || (ZSTD_VERSION_MINOR < 3)
29
+
30
+ int32_t cgozstd_compress_start(ZSTD_CCtx* z,
31
+ uint8_t* dict_ptr,
32
+ uint32_t dict_len,
33
+ int compression_level) {
34
+ if (dict_len > 0) {
35
+ return -1;
36
+ }
37
+ return ZSTD_getErrorCode(ZSTD_initCStream(z, compression_level));
38
+ }
39
+
40
+ int32_t cgozstd_decompress_start(ZSTD_DCtx* z,
41
+ uint8_t* dict_ptr,
42
+ uint32_t dict_len) {
43
+ if (dict_len > 0) {
44
+ return -1;
45
+ }
46
+ return ZSTD_getErrorCode(ZSTD_initDStream(z));
47
+ }
48
+
49
+ #else
50
+
51
+ // TODO: don't use the unsupported ZSTD_initFoo_usingDict API.
52
+ ZSTDLIB_API size_t ZSTD_initCStream_usingDict(
53
+ ZSTD_CCtx* z,
54
+ const void* dict_ptr,
55
+ size_t dict_len,
56
+ int compression_level);
57
+ ZSTDLIB_API size_t ZSTD_initDStream_usingDict(
58
+ ZSTD_DCtx* z,
59
+ const void* dict_ptr,
60
+ size_t dict_len);
61
+
62
+ int32_t cgozstd_compress_start(ZSTD_CCtx* z,
63
+ uint8_t* dict_ptr,
64
+ uint32_t dict_len,
65
+ int compression_level) {
66
+ return ZSTD_getErrorCode(ZSTD_initCStream_usingDict(
67
+ z, dict_ptr, dict_len, compression_level));
68
+ }
69
+
70
+ int32_t cgozstd_decompress_start(ZSTD_DCtx* z,
71
+ uint8_t* dict_ptr,
72
+ uint32_t dict_len) {
73
+ return ZSTD_getErrorCode(ZSTD_initDStream_usingDict(
74
+ z, dict_ptr, dict_len));
75
+ }
76
+
77
+ #endif
78
+ // --------
79
+
29
80
typedef struct {
30
81
uint32_t ndst;
31
82
uint32_t nsrc;
@@ -116,11 +167,12 @@ const cgoEnabled = true
116
167
const maxLen = 1 << 30
117
168
118
169
var (
119
- errMissingResetCall = errors .New ("cgozstd: missing Reset call" )
120
- errNilIOReader = errors .New ("cgozstd: nil io.Reader" )
121
- errNilIOWriter = errors .New ("cgozstd: nil io.Writer" )
122
- errNilReceiver = errors .New ("cgozstd: nil receiver" )
123
- errOutOfMemory = errors .New ("cgozstd: out of memory" )
170
+ errMissingResetCall = errors .New ("cgozstd: missing Reset call" )
171
+ errNilIOReader = errors .New ("cgozstd: nil io.Reader" )
172
+ errNilIOWriter = errors .New ("cgozstd: nil io.Writer" )
173
+ errNilReceiver = errors .New ("cgozstd: nil receiver" )
174
+ errOutOfMemory = errors .New ("cgozstd: out of memory" )
175
+ errZstdVersionTooSmall = errors .New ("cgozstd: zstd version too small (1.3 minimum)" )
124
176
)
125
177
126
178
type errCode int32
@@ -165,6 +217,8 @@ type Reader struct {
165
217
i , j uint32
166
218
r io.Reader
167
219
220
+ dictionary []byte
221
+
168
222
readErr error
169
223
zstdErr error
170
224
@@ -185,7 +239,12 @@ func (r *Reader) Reset(reader io.Reader, dictionary []byte) error {
185
239
if reader == nil {
186
240
return errNilIOReader
187
241
}
242
+ if len (dictionary ) > maxLen {
243
+ dictionary = dictionary [len (dictionary )- maxLen :]
244
+ }
245
+
188
246
r .r = reader
247
+ r .dictionary = dictionary
189
248
return nil
190
249
}
191
250
@@ -225,13 +284,32 @@ func (r *Reader) Read(p []byte) (int, error) {
225
284
if r .z == nil {
226
285
if (r .recycler != nil ) && ! r .recycler .closed && (r .recycler .z != nil ) {
227
286
r .z , r .recycler .z = r .recycler .z , nil
228
- C .ZSTD_initDStream (r .z )
229
287
} else {
230
288
r .z = C .ZSTD_createDStream ()
231
289
if r .z == nil {
232
290
return 0 , errOutOfMemory
233
291
}
234
292
}
293
+
294
+ e := errCode (0 )
295
+ if len (r .dictionary ) == 0 {
296
+ e = errCode (C .cgozstd_decompress_start (r .z ,
297
+ (* C .uint8_t )(nil ),
298
+ (C .uint32_t )(0 ),
299
+ ))
300
+ } else {
301
+ e = errCode (C .cgozstd_decompress_start (r .z ,
302
+ (* C .uint8_t )(unsafe .Pointer (& r .dictionary [0 ])),
303
+ (C .uint32_t )(len (r .dictionary )),
304
+ ))
305
+ }
306
+ if e < 0 {
307
+ r .zstdErr = errZstdVersionTooSmall
308
+ return 0 , r .zstdErr
309
+ } else if e != 0 {
310
+ r .zstdErr = e
311
+ return 0 , r .zstdErr
312
+ }
235
313
}
236
314
237
315
if len (p ) > maxLen {
@@ -320,6 +398,8 @@ type Writer struct {
320
398
w io.Writer
321
399
level compression.Level
322
400
401
+ dictionary []byte
402
+
323
403
writeErr error
324
404
325
405
recycler * WriterRecycler
@@ -341,8 +421,13 @@ func (w *Writer) Reset(writer io.Writer, dictionary []byte, level compression.Le
341
421
if writer == nil {
342
422
return errNilIOWriter
343
423
}
424
+ if len (dictionary ) > maxLen {
425
+ dictionary = dictionary [len (dictionary )- maxLen :]
426
+ }
427
+
344
428
w .w = writer
345
429
w .level = level
430
+ w .dictionary = dictionary
346
431
return nil
347
432
}
348
433
@@ -436,7 +521,28 @@ func (w *Writer) write(p []byte, final bool) error {
436
521
return errOutOfMemory
437
522
}
438
523
}
439
- C .ZSTD_initCStream (w .z , C .int (w .zstdCompressionLevel ()))
524
+
525
+ e := errCode (0 )
526
+ if len (w .dictionary ) == 0 {
527
+ e = errCode (C .cgozstd_compress_start (w .z ,
528
+ (* C .uint8_t )(nil ),
529
+ (C .uint32_t )(0 ),
530
+ C .int (w .zstdCompressionLevel ()),
531
+ ))
532
+ } else {
533
+ e = errCode (C .cgozstd_compress_start (w .z ,
534
+ (* C .uint8_t )(unsafe .Pointer (& w .dictionary [0 ])),
535
+ (C .uint32_t )(len (w .dictionary )),
536
+ C .int (w .zstdCompressionLevel ()),
537
+ ))
538
+ }
539
+ if e < 0 {
540
+ w .writeErr = errZstdVersionTooSmall
541
+ return w .writeErr
542
+ } else if e != 0 {
543
+ w .writeErr = e
544
+ return w .writeErr
545
+ }
440
546
}
441
547
442
548
for (len (p ) > 0 ) || final {
0 commit comments