entropy_common.c 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. // SPDX-License-Identifier: (GPL-2.0 or BSD-2-Clause)
  2. /*
  3. * Common functions of New Generation Entropy library
  4. * Copyright (C) 2016, Yann Collet.
  5. *
  6. * You can contact the author at :
  7. * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy
  8. */
  9. /* *************************************
  10. * Dependencies
  11. ***************************************/
  12. #include "error_private.h" /* ERR_*, ERROR */
  13. #include "fse.h"
  14. #include "huf.h"
  15. #include "mem.h"
  16. /*=== Version ===*/
  17. unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
  18. /*=== Error Management ===*/
  19. unsigned FSE_isError(size_t code) { return ERR_isError(code); }
  20. unsigned HUF_isError(size_t code) { return ERR_isError(code); }
  21. /*-**************************************************************
  22. * FSE NCount encoding-decoding
  23. ****************************************************************/
  24. size_t FSE_readNCount(short *normalizedCounter, unsigned *maxSVPtr, unsigned *tableLogPtr, const void *headerBuffer, size_t hbSize)
  25. {
  26. const BYTE *const istart = (const BYTE *)headerBuffer;
  27. const BYTE *const iend = istart + hbSize;
  28. const BYTE *ip = istart;
  29. int nbBits;
  30. int remaining;
  31. int threshold;
  32. U32 bitStream;
  33. int bitCount;
  34. unsigned charnum = 0;
  35. int previous0 = 0;
  36. if (hbSize < 4)
  37. return ERROR(srcSize_wrong);
  38. bitStream = ZSTD_readLE32(ip);
  39. nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */
  40. if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX)
  41. return ERROR(tableLog_tooLarge);
  42. bitStream >>= 4;
  43. bitCount = 4;
  44. *tableLogPtr = nbBits;
  45. remaining = (1 << nbBits) + 1;
  46. threshold = 1 << nbBits;
  47. nbBits++;
  48. while ((remaining > 1) & (charnum <= *maxSVPtr)) {
  49. if (previous0) {
  50. unsigned n0 = charnum;
  51. while ((bitStream & 0xFFFF) == 0xFFFF) {
  52. n0 += 24;
  53. if (ip < iend - 5) {
  54. ip += 2;
  55. bitStream = ZSTD_readLE32(ip) >> bitCount;
  56. } else {
  57. bitStream >>= 16;
  58. bitCount += 16;
  59. }
  60. }
  61. while ((bitStream & 3) == 3) {
  62. n0 += 3;
  63. bitStream >>= 2;
  64. bitCount += 2;
  65. }
  66. n0 += bitStream & 3;
  67. bitCount += 2;
  68. if (n0 > *maxSVPtr)
  69. return ERROR(maxSymbolValue_tooSmall);
  70. while (charnum < n0)
  71. normalizedCounter[charnum++] = 0;
  72. if ((ip <= iend - 7) || (ip + (bitCount >> 3) <= iend - 4)) {
  73. ip += bitCount >> 3;
  74. bitCount &= 7;
  75. bitStream = ZSTD_readLE32(ip) >> bitCount;
  76. } else {
  77. bitStream >>= 2;
  78. }
  79. }
  80. {
  81. int const max = (2 * threshold - 1) - remaining;
  82. int count;
  83. if ((bitStream & (threshold - 1)) < (U32)max) {
  84. count = bitStream & (threshold - 1);
  85. bitCount += nbBits - 1;
  86. } else {
  87. count = bitStream & (2 * threshold - 1);
  88. if (count >= threshold)
  89. count -= max;
  90. bitCount += nbBits;
  91. }
  92. count--; /* extra accuracy */
  93. remaining -= count < 0 ? -count : count; /* -1 means +1 */
  94. normalizedCounter[charnum++] = (short)count;
  95. previous0 = !count;
  96. while (remaining < threshold) {
  97. nbBits--;
  98. threshold >>= 1;
  99. }
  100. if ((ip <= iend - 7) || (ip + (bitCount >> 3) <= iend - 4)) {
  101. ip += bitCount >> 3;
  102. bitCount &= 7;
  103. } else {
  104. bitCount -= (int)(8 * (iend - 4 - ip));
  105. ip = iend - 4;
  106. }
  107. bitStream = ZSTD_readLE32(ip) >> (bitCount & 31);
  108. }
  109. } /* while ((remaining>1) & (charnum<=*maxSVPtr)) */
  110. if (remaining != 1)
  111. return ERROR(corruption_detected);
  112. if (bitCount > 32)
  113. return ERROR(corruption_detected);
  114. *maxSVPtr = charnum - 1;
  115. ip += (bitCount + 7) >> 3;
  116. return ip - istart;
  117. }
  118. /*! HUF_readStats() :
  119. Read compact Huffman tree, saved by HUF_writeCTable().
  120. `huffWeight` is destination buffer.
  121. `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
  122. @return : size read from `src` , or an error Code .
  123. Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
  124. */
  125. size_t HUF_readStats_wksp(BYTE *huffWeight, size_t hwSize, U32 *rankStats, U32 *nbSymbolsPtr, U32 *tableLogPtr, const void *src, size_t srcSize, void *workspace, size_t workspaceSize)
  126. {
  127. U32 weightTotal;
  128. const BYTE *ip = (const BYTE *)src;
  129. size_t iSize;
  130. size_t oSize;
  131. if (!srcSize)
  132. return ERROR(srcSize_wrong);
  133. iSize = ip[0];
  134. /* memset(huffWeight, 0, hwSize); */ /* is not necessary, even though some analyzer complain ... */
  135. if (iSize >= 128) { /* special header */
  136. oSize = iSize - 127;
  137. iSize = ((oSize + 1) / 2);
  138. if (iSize + 1 > srcSize)
  139. return ERROR(srcSize_wrong);
  140. if (oSize >= hwSize)
  141. return ERROR(corruption_detected);
  142. ip += 1;
  143. {
  144. U32 n;
  145. for (n = 0; n < oSize; n += 2) {
  146. huffWeight[n] = ip[n / 2] >> 4;
  147. huffWeight[n + 1] = ip[n / 2] & 15;
  148. }
  149. }
  150. } else { /* header compressed with FSE (normal case) */
  151. if (iSize + 1 > srcSize)
  152. return ERROR(srcSize_wrong);
  153. oSize = FSE_decompress_wksp(huffWeight, hwSize - 1, ip + 1, iSize, 6, workspace, workspaceSize); /* max (hwSize-1) values decoded, as last one is implied */
  154. if (FSE_isError(oSize))
  155. return oSize;
  156. }
  157. /* collect weight stats */
  158. memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
  159. weightTotal = 0;
  160. {
  161. U32 n;
  162. for (n = 0; n < oSize; n++) {
  163. if (huffWeight[n] >= HUF_TABLELOG_MAX)
  164. return ERROR(corruption_detected);
  165. rankStats[huffWeight[n]]++;
  166. weightTotal += (1 << huffWeight[n]) >> 1;
  167. }
  168. }
  169. if (weightTotal == 0)
  170. return ERROR(corruption_detected);
  171. /* get last non-null symbol weight (implied, total must be 2^n) */
  172. {
  173. U32 const tableLog = BIT_highbit32(weightTotal) + 1;
  174. if (tableLog > HUF_TABLELOG_MAX)
  175. return ERROR(corruption_detected);
  176. *tableLogPtr = tableLog;
  177. /* determine last weight */
  178. {
  179. U32 const total = 1 << tableLog;
  180. U32 const rest = total - weightTotal;
  181. U32 const verif = 1 << BIT_highbit32(rest);
  182. U32 const lastWeight = BIT_highbit32(rest) + 1;
  183. if (verif != rest)
  184. return ERROR(corruption_detected); /* last value must be a clean power of 2 */
  185. huffWeight[oSize] = (BYTE)lastWeight;
  186. rankStats[lastWeight]++;
  187. }
  188. }
  189. /* check tree construction validity */
  190. if ((rankStats[1] < 2) || (rankStats[1] & 1))
  191. return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */
  192. /* results */
  193. *nbSymbolsPtr = (U32)(oSize + 1);
  194. return iSize + 1;
  195. }