zstd.c 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Cryptographic API.
  4. *
  5. * Copyright (c) 2017-present, Facebook, Inc.
  6. */
  7. #include <linux/crypto.h>
  8. #include <linux/init.h>
  9. #include <linux/interrupt.h>
  10. #include <linux/mm.h>
  11. #include <linux/module.h>
  12. #include <linux/net.h>
  13. #include <linux/vmalloc.h>
  14. #include <linux/zstd.h>
  15. #include <crypto/internal/scompress.h>
/* Compression level passed to ZSTD_getParams() by zstd_params(). */
#define ZSTD_DEF_LEVEL 3

/*
 * Per-transform state shared by the crypto_alg and scomp_alg front ends.
 * cctx/dctx are created inside cwksp/dwksp by ZSTD_initCCtx()/ZSTD_initDCtx(),
 * so only the workspaces are ever freed (see zstd_comp_exit/zstd_decomp_exit).
 */
struct zstd_ctx {
ZSTD_CCtx *cctx;	/* compression context, lives inside cwksp */
ZSTD_DCtx *dctx;	/* decompression context, lives inside dwksp */
void *cwksp;		/* vzalloc'd compression workspace */
void *dwksp;		/* vzalloc'd decompression workspace */
};
  23. static ZSTD_parameters zstd_params(void)
  24. {
  25. return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
  26. }
  27. static int zstd_comp_init(struct zstd_ctx *ctx)
  28. {
  29. int ret = 0;
  30. const ZSTD_parameters params = zstd_params();
  31. const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
  32. ctx->cwksp = vzalloc(wksp_size);
  33. if (!ctx->cwksp) {
  34. ret = -ENOMEM;
  35. goto out;
  36. }
  37. ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
  38. if (!ctx->cctx) {
  39. ret = -EINVAL;
  40. goto out_free;
  41. }
  42. out:
  43. return ret;
  44. out_free:
  45. vfree(ctx->cwksp);
  46. goto out;
  47. }
  48. static int zstd_decomp_init(struct zstd_ctx *ctx)
  49. {
  50. int ret = 0;
  51. const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
  52. ctx->dwksp = vzalloc(wksp_size);
  53. if (!ctx->dwksp) {
  54. ret = -ENOMEM;
  55. goto out;
  56. }
  57. ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
  58. if (!ctx->dctx) {
  59. ret = -EINVAL;
  60. goto out_free;
  61. }
  62. out:
  63. return ret;
  64. out_free:
  65. vfree(ctx->dwksp);
  66. goto out;
  67. }
  68. static void zstd_comp_exit(struct zstd_ctx *ctx)
  69. {
  70. vfree(ctx->cwksp);
  71. ctx->cwksp = NULL;
  72. ctx->cctx = NULL;
  73. }
  74. static void zstd_decomp_exit(struct zstd_ctx *ctx)
  75. {
  76. vfree(ctx->dwksp);
  77. ctx->dwksp = NULL;
  78. ctx->dctx = NULL;
  79. }
  80. static int __zstd_init(void *ctx)
  81. {
  82. int ret;
  83. ret = zstd_comp_init(ctx);
  84. if (ret)
  85. return ret;
  86. ret = zstd_decomp_init(ctx);
  87. if (ret)
  88. zstd_comp_exit(ctx);
  89. return ret;
  90. }
  91. static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
  92. {
  93. int ret;
  94. struct zstd_ctx *ctx;
  95. ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
  96. if (!ctx)
  97. return ERR_PTR(-ENOMEM);
  98. ret = __zstd_init(ctx);
  99. if (ret) {
  100. kfree(ctx);
  101. return ERR_PTR(ret);
  102. }
  103. return ctx;
  104. }
  105. static int zstd_init(struct crypto_tfm *tfm)
  106. {
  107. struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
  108. return __zstd_init(ctx);
  109. }
  110. static void __zstd_exit(void *ctx)
  111. {
  112. zstd_comp_exit(ctx);
  113. zstd_decomp_exit(ctx);
  114. }
  115. static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
  116. {
  117. __zstd_exit(ctx);
  118. kfree_sensitive(ctx);
  119. }
  120. static void zstd_exit(struct crypto_tfm *tfm)
  121. {
  122. struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
  123. __zstd_exit(ctx);
  124. }
  125. static int __zstd_compress(const u8 *src, unsigned int slen,
  126. u8 *dst, unsigned int *dlen, void *ctx)
  127. {
  128. size_t out_len;
  129. struct zstd_ctx *zctx = ctx;
  130. const ZSTD_parameters params = zstd_params();
  131. out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params);
  132. if (ZSTD_isError(out_len))
  133. return -EINVAL;
  134. *dlen = out_len;
  135. return 0;
  136. }
  137. static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
  138. unsigned int slen, u8 *dst, unsigned int *dlen)
  139. {
  140. struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
  141. return __zstd_compress(src, slen, dst, dlen, ctx);
  142. }
  143. static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
  144. unsigned int slen, u8 *dst, unsigned int *dlen,
  145. void *ctx)
  146. {
  147. return __zstd_compress(src, slen, dst, dlen, ctx);
  148. }
  149. static int __zstd_decompress(const u8 *src, unsigned int slen,
  150. u8 *dst, unsigned int *dlen, void *ctx)
  151. {
  152. size_t out_len;
  153. struct zstd_ctx *zctx = ctx;
  154. out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen);
  155. if (ZSTD_isError(out_len))
  156. return -EINVAL;
  157. *dlen = out_len;
  158. return 0;
  159. }
  160. static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
  161. unsigned int slen, u8 *dst, unsigned int *dlen)
  162. {
  163. struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
  164. return __zstd_decompress(src, slen, dst, dlen, ctx);
  165. }
  166. static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
  167. unsigned int slen, u8 *dst, unsigned int *dlen,
  168. void *ctx)
  169. {
  170. return __zstd_decompress(src, slen, dst, dlen, ctx);
  171. }
  172. static struct crypto_alg alg = {
  173. .cra_name = "zstd",
  174. .cra_driver_name = "zstd-generic",
  175. .cra_flags = CRYPTO_ALG_TYPE_COMPRESS,
  176. .cra_ctxsize = sizeof(struct zstd_ctx),
  177. .cra_module = THIS_MODULE,
  178. .cra_init = zstd_init,
  179. .cra_exit = zstd_exit,
  180. .cra_u = { .compress = {
  181. .coa_compress = zstd_compress,
  182. .coa_decompress = zstd_decompress } }
  183. };
  184. static struct scomp_alg scomp = {
  185. .alloc_ctx = zstd_alloc_ctx,
  186. .free_ctx = zstd_free_ctx,
  187. .compress = zstd_scompress,
  188. .decompress = zstd_sdecompress,
  189. .base = {
  190. .cra_name = "zstd",
  191. .cra_driver_name = "zstd-scomp",
  192. .cra_module = THIS_MODULE,
  193. }
  194. };
  195. static int __init zstd_mod_init(void)
  196. {
  197. int ret;
  198. ret = crypto_register_alg(&alg);
  199. if (ret)
  200. return ret;
  201. ret = crypto_register_scomp(&scomp);
  202. if (ret)
  203. crypto_unregister_alg(&alg);
  204. return ret;
  205. }
  206. static void __exit zstd_mod_fini(void)
  207. {
  208. crypto_unregister_alg(&alg);
  209. crypto_unregister_scomp(&scomp);
  210. }
/*
 * Init/exit hooks. subsys_initcall() is used instead of module_init();
 * NOTE(review): presumably so the algorithms are registered early when
 * built in - confirm.
 */
subsys_initcall(zstd_mod_init);
module_exit(zstd_mod_fini);
/*
 * Module metadata. MODULE_ALIAS_CRYPTO("zstd") presumably lets the crypto
 * core autoload this module when "zstd" is requested - confirm.
 */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Zstd Compression Algorithm");
MODULE_ALIAS_CRYPTO("zstd");