zstd.c 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. // SPDX-License-Identifier: GPL-2.0
  2. #include <string.h>
  3. #include "util/compress.h"
  4. #include "util/debug.h"
  5. int zstd_init(struct zstd_data *data, int level)
  6. {
  7. size_t ret;
  8. data->dstream = ZSTD_createDStream();
  9. if (data->dstream == NULL) {
  10. pr_err("Couldn't create decompression stream.\n");
  11. return -1;
  12. }
  13. ret = ZSTD_initDStream(data->dstream);
  14. if (ZSTD_isError(ret)) {
  15. pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret));
  16. return -1;
  17. }
  18. if (!level)
  19. return 0;
  20. data->cstream = ZSTD_createCStream();
  21. if (data->cstream == NULL) {
  22. pr_err("Couldn't create compression stream.\n");
  23. return -1;
  24. }
  25. ret = ZSTD_initCStream(data->cstream, level);
  26. if (ZSTD_isError(ret)) {
  27. pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret));
  28. return -1;
  29. }
  30. return 0;
  31. }
  32. int zstd_fini(struct zstd_data *data)
  33. {
  34. if (data->dstream) {
  35. ZSTD_freeDStream(data->dstream);
  36. data->dstream = NULL;
  37. }
  38. if (data->cstream) {
  39. ZSTD_freeCStream(data->cstream);
  40. data->cstream = NULL;
  41. }
  42. return 0;
  43. }
  44. size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
  45. void *src, size_t src_size, size_t max_record_size,
  46. size_t process_header(void *record, size_t increment))
  47. {
  48. size_t ret, size, compressed = 0;
  49. ZSTD_inBuffer input = { src, src_size, 0 };
  50. ZSTD_outBuffer output;
  51. void *record;
  52. while (input.pos < input.size) {
  53. record = dst;
  54. size = process_header(record, 0);
  55. compressed += size;
  56. dst += size;
  57. dst_size -= size;
  58. output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
  59. max_record_size : dst_size, 0 };
  60. ret = ZSTD_compressStream(data->cstream, &output, &input);
  61. ZSTD_flushStream(data->cstream, &output);
  62. if (ZSTD_isError(ret)) {
  63. pr_err("failed to compress %ld bytes: %s\n",
  64. (long)src_size, ZSTD_getErrorName(ret));
  65. memcpy(dst, src, src_size);
  66. return src_size;
  67. }
  68. size = output.pos;
  69. size = process_header(record, size);
  70. compressed += size;
  71. dst += size;
  72. dst_size -= size;
  73. }
  74. return compressed;
  75. }
  76. size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
  77. void *dst, size_t dst_size)
  78. {
  79. size_t ret;
  80. ZSTD_inBuffer input = { src, src_size, 0 };
  81. ZSTD_outBuffer output = { dst, dst_size, 0 };
  82. while (input.pos < input.size) {
  83. ret = ZSTD_decompressStream(data->dstream, &output, &input);
  84. if (ZSTD_isError(ret)) {
  85. pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
  86. src_size, output.size, dst_size, ZSTD_getErrorName(ret));
  87. break;
  88. }
  89. output.dst = dst + output.pos;
  90. output.size = dst_size - output.pos;
  91. }
  92. return output.pos;
  93. }