cl_half.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. /*******************************************************************************
  2. * Copyright (c) 2019-2020 The Khronos Group Inc.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. ******************************************************************************/
  16. /**
  17. * This is a header-only utility library that provides OpenCL host code with
  18. * routines for converting to/from cl_half values.
  19. *
  20. * Example usage:
  21. *
  22. * #include <CL/cl_half.h>
  23. * ...
  24. * cl_half h = cl_half_from_float(0.5f, CL_HALF_RTE);
  25. * cl_float f = cl_half_to_float(h);
  26. */
  27. #ifndef OPENCL_CL_HALF_H
  28. #define OPENCL_CL_HALF_H
  29. #include <CL/cl_platform.h>
  30. #include <stdint.h>
  31. #ifdef __cplusplus
  32. extern "C" {
  33. #endif
  /**
   * Rounding direction applied when a value converted to cl_half is not
   * exactly representable in half precision.
   */
  typedef enum
  {
    CL_HALF_RTE = 0, /* round to nearest, ties to even */
    CL_HALF_RTZ = 1, /* round toward zero (truncate) */
    CL_HALF_RTP = 2, /* round toward positive infinity */
    CL_HALF_RTN = 3  /* round toward negative infinity */
  } cl_half_rounding_mode;
  /* Private bit patterns used by the conversion routines in this header
   * (they are #undef'd at the end of the header). */
  /* All five bits of the half-precision exponent field (bits 10..14). */
  #define CL_HALF_EXP_MASK (0x1F << 10)
  /* Exponent field one below all-ones with a full mantissa: the magnitude
   * bits of the largest finite half value. */
  #define CL_HALF_MAX_FINITE_MAG (CL_HALF_EXP_MASK - 1)
  47. /*
  48. * Utility to deal with values that overflow when converting to half precision.
  49. */
  50. static inline cl_half cl_half_handle_overflow(cl_half_rounding_mode rounding_mode,
  51. uint16_t sign)
  52. {
  53. if (rounding_mode == CL_HALF_RTZ)
  54. {
  55. // Round overflow towards zero -> largest finite number (preserving sign)
  56. return (sign << 15) | CL_HALF_MAX_FINITE_MAG;
  57. }
  58. else if (rounding_mode == CL_HALF_RTP && sign)
  59. {
  60. // Round negative overflow towards positive infinity -> most negative finite number
  61. return (1 << 15) | CL_HALF_MAX_FINITE_MAG;
  62. }
  63. else if (rounding_mode == CL_HALF_RTN && !sign)
  64. {
  65. // Round positive overflow towards negative infinity -> largest finite number
  66. return CL_HALF_MAX_FINITE_MAG;
  67. }
  68. // Overflow to infinity
  69. return (sign << 15) | CL_HALF_EXP_MASK;
  70. }
  71. /*
  72. * Utility to deal with values that underflow when converting to half precision.
  73. */
  74. static inline cl_half cl_half_handle_underflow(cl_half_rounding_mode rounding_mode,
  75. uint16_t sign)
  76. {
  77. if (rounding_mode == CL_HALF_RTP && !sign)
  78. {
  79. // Round underflow towards positive infinity -> smallest positive value
  80. return (sign << 15) | 1;
  81. }
  82. else if (rounding_mode == CL_HALF_RTN && sign)
  83. {
  84. // Round underflow towards negative infinity -> largest negative value
  85. return (sign << 15) | 1;
  86. }
  87. // Flush to zero
  88. return (sign << 15);
  89. }
  90. /**
  91. * Convert a cl_float to a cl_half.
  92. */
  93. static inline cl_half cl_half_from_float(cl_float f, cl_half_rounding_mode rounding_mode)
  94. {
  95. // Type-punning to get direct access to underlying bits
  96. union
  97. {
  98. cl_float f;
  99. uint32_t i;
  100. } f32;
  101. f32.f = f;
  102. // Extract sign bit
  103. uint16_t sign = f32.i >> 31;
  104. // Extract FP32 exponent and mantissa
  105. uint32_t f_exp = (f32.i >> (CL_FLT_MANT_DIG - 1)) & 0xFF;
  106. uint32_t f_mant = f32.i & ((1 << (CL_FLT_MANT_DIG - 1)) - 1);
  107. // Remove FP32 exponent bias
  108. int32_t exp = f_exp - CL_FLT_MAX_EXP + 1;
  109. // Add FP16 exponent bias
  110. uint16_t h_exp = (uint16_t)(exp + CL_HALF_MAX_EXP - 1);
  111. // Position of the bit that will become the FP16 mantissa LSB
  112. uint32_t lsb_pos = CL_FLT_MANT_DIG - CL_HALF_MANT_DIG;
  113. // Check for NaN / infinity
  114. if (f_exp == 0xFF)
  115. {
  116. if (f_mant)
  117. {
  118. // NaN -> propagate mantissa and silence it
  119. uint16_t h_mant = (uint16_t)(f_mant >> lsb_pos);
  120. h_mant |= 0x200;
  121. return (sign << 15) | CL_HALF_EXP_MASK | h_mant;
  122. }
  123. else
  124. {
  125. // Infinity -> zero mantissa
  126. return (sign << 15) | CL_HALF_EXP_MASK;
  127. }
  128. }
  129. // Check for zero
  130. if (!f_exp && !f_mant)
  131. {
  132. return (sign << 15);
  133. }
  134. // Check for overflow
  135. if (exp >= CL_HALF_MAX_EXP)
  136. {
  137. return cl_half_handle_overflow(rounding_mode, sign);
  138. }
  139. // Check for underflow
  140. if (exp < (CL_HALF_MIN_EXP - CL_HALF_MANT_DIG - 1))
  141. {
  142. return cl_half_handle_underflow(rounding_mode, sign);
  143. }
  144. // Check for value that will become denormal
  145. if (exp < -14)
  146. {
  147. // Denormal -> include the implicit 1 from the FP32 mantissa
  148. h_exp = 0;
  149. f_mant |= 1 << (CL_FLT_MANT_DIG - 1);
  150. // Mantissa shift amount depends on exponent
  151. lsb_pos = -exp + (CL_FLT_MANT_DIG - 25);
  152. }
  153. // Generate FP16 mantissa by shifting FP32 mantissa
  154. uint16_t h_mant = (uint16_t)(f_mant >> lsb_pos);
  155. // Check whether we need to round
  156. uint32_t halfway = 1 << (lsb_pos - 1);
  157. uint32_t mask = (halfway << 1) - 1;
  158. switch (rounding_mode)
  159. {
  160. case CL_HALF_RTE:
  161. if ((f_mant & mask) > halfway)
  162. {
  163. // More than halfway -> round up
  164. h_mant += 1;
  165. }
  166. else if ((f_mant & mask) == halfway)
  167. {
  168. // Exactly halfway -> round to nearest even
  169. if (h_mant & 0x1)
  170. h_mant += 1;
  171. }
  172. break;
  173. case CL_HALF_RTZ:
  174. // Mantissa has already been truncated -> do nothing
  175. break;
  176. case CL_HALF_RTP:
  177. if ((f_mant & mask) && !sign)
  178. {
  179. // Round positive numbers up
  180. h_mant += 1;
  181. }
  182. break;
  183. case CL_HALF_RTN:
  184. if ((f_mant & mask) && sign)
  185. {
  186. // Round negative numbers down
  187. h_mant += 1;
  188. }
  189. break;
  190. }
  191. // Check for mantissa overflow
  192. if (h_mant & 0x400)
  193. {
  194. h_exp += 1;
  195. h_mant = 0;
  196. }
  197. return (sign << 15) | (h_exp << 10) | h_mant;
  198. }
  199. /**
  200. * Convert a cl_double to a cl_half.
  201. */
  202. static inline cl_half cl_half_from_double(cl_double d, cl_half_rounding_mode rounding_mode)
  203. {
  204. // Type-punning to get direct access to underlying bits
  205. union
  206. {
  207. cl_double d;
  208. uint64_t i;
  209. } f64;
  210. f64.d = d;
  211. // Extract sign bit
  212. uint16_t sign = f64.i >> 63;
  213. // Extract FP64 exponent and mantissa
  214. uint64_t d_exp = (f64.i >> (CL_DBL_MANT_DIG - 1)) & 0x7FF;
  215. uint64_t d_mant = f64.i & (((uint64_t)1 << (CL_DBL_MANT_DIG - 1)) - 1);
  216. // Remove FP64 exponent bias
  217. int64_t exp = d_exp - CL_DBL_MAX_EXP + 1;
  218. // Add FP16 exponent bias
  219. uint16_t h_exp = (uint16_t)(exp + CL_HALF_MAX_EXP - 1);
  220. // Position of the bit that will become the FP16 mantissa LSB
  221. uint32_t lsb_pos = CL_DBL_MANT_DIG - CL_HALF_MANT_DIG;
  222. // Check for NaN / infinity
  223. if (d_exp == 0x7FF)
  224. {
  225. if (d_mant)
  226. {
  227. // NaN -> propagate mantissa and silence it
  228. uint16_t h_mant = (uint16_t)(d_mant >> lsb_pos);
  229. h_mant |= 0x200;
  230. return (sign << 15) | CL_HALF_EXP_MASK | h_mant;
  231. }
  232. else
  233. {
  234. // Infinity -> zero mantissa
  235. return (sign << 15) | CL_HALF_EXP_MASK;
  236. }
  237. }
  238. // Check for zero
  239. if (!d_exp && !d_mant)
  240. {
  241. return (sign << 15);
  242. }
  243. // Check for overflow
  244. if (exp >= CL_HALF_MAX_EXP)
  245. {
  246. return cl_half_handle_overflow(rounding_mode, sign);
  247. }
  248. // Check for underflow
  249. if (exp < (CL_HALF_MIN_EXP - CL_HALF_MANT_DIG - 1))
  250. {
  251. return cl_half_handle_underflow(rounding_mode, sign);
  252. }
  253. // Check for value that will become denormal
  254. if (exp < -14)
  255. {
  256. // Include the implicit 1 from the FP64 mantissa
  257. h_exp = 0;
  258. d_mant |= (uint64_t)1 << (CL_DBL_MANT_DIG - 1);
  259. // Mantissa shift amount depends on exponent
  260. lsb_pos = (uint32_t)(-exp + (CL_DBL_MANT_DIG - 25));
  261. }
  262. // Generate FP16 mantissa by shifting FP64 mantissa
  263. uint16_t h_mant = (uint16_t)(d_mant >> lsb_pos);
  264. // Check whether we need to round
  265. uint64_t halfway = (uint64_t)1 << (lsb_pos - 1);
  266. uint64_t mask = (halfway << 1) - 1;
  267. switch (rounding_mode)
  268. {
  269. case CL_HALF_RTE:
  270. if ((d_mant & mask) > halfway)
  271. {
  272. // More than halfway -> round up
  273. h_mant += 1;
  274. }
  275. else if ((d_mant & mask) == halfway)
  276. {
  277. // Exactly halfway -> round to nearest even
  278. if (h_mant & 0x1)
  279. h_mant += 1;
  280. }
  281. break;
  282. case CL_HALF_RTZ:
  283. // Mantissa has already been truncated -> do nothing
  284. break;
  285. case CL_HALF_RTP:
  286. if ((d_mant & mask) && !sign)
  287. {
  288. // Round positive numbers up
  289. h_mant += 1;
  290. }
  291. break;
  292. case CL_HALF_RTN:
  293. if ((d_mant & mask) && sign)
  294. {
  295. // Round negative numbers down
  296. h_mant += 1;
  297. }
  298. break;
  299. }
  300. // Check for mantissa overflow
  301. if (h_mant & 0x400)
  302. {
  303. h_exp += 1;
  304. h_mant = 0;
  305. }
  306. return (sign << 15) | (h_exp << 10) | h_mant;
  307. }
  308. /**
  309. * Convert a cl_half to a cl_float.
  310. */
  311. static inline cl_float cl_half_to_float(cl_half h)
  312. {
  313. // Type-punning to get direct access to underlying bits
  314. union
  315. {
  316. cl_float f;
  317. uint32_t i;
  318. } f32;
  319. // Extract sign bit
  320. uint16_t sign = h >> 15;
  321. // Extract FP16 exponent and mantissa
  322. uint16_t h_exp = (h >> (CL_HALF_MANT_DIG - 1)) & 0x1F;
  323. uint16_t h_mant = h & 0x3FF;
  324. // Remove FP16 exponent bias
  325. int32_t exp = h_exp - CL_HALF_MAX_EXP + 1;
  326. // Add FP32 exponent bias
  327. uint32_t f_exp = exp + CL_FLT_MAX_EXP - 1;
  328. // Check for NaN / infinity
  329. if (h_exp == 0x1F)
  330. {
  331. if (h_mant)
  332. {
  333. // NaN -> propagate mantissa and silence it
  334. uint32_t f_mant = h_mant << (CL_FLT_MANT_DIG - CL_HALF_MANT_DIG);
  335. f_mant |= 0x400000;
  336. f32.i = (sign << 31) | 0x7F800000 | f_mant;
  337. return f32.f;
  338. }
  339. else
  340. {
  341. // Infinity -> zero mantissa
  342. f32.i = (sign << 31) | 0x7F800000;
  343. return f32.f;
  344. }
  345. }
  346. // Check for zero / denormal
  347. if (h_exp == 0)
  348. {
  349. if (h_mant == 0)
  350. {
  351. // Zero -> zero exponent
  352. f_exp = 0;
  353. }
  354. else
  355. {
  356. // Denormal -> normalize it
  357. // - Shift mantissa to make most-significant 1 implicit
  358. // - Adjust exponent accordingly
  359. uint32_t shift = 0;
  360. while ((h_mant & 0x400) == 0)
  361. {
  362. h_mant <<= 1;
  363. shift++;
  364. }
  365. h_mant &= 0x3FF;
  366. f_exp -= shift - 1;
  367. }
  368. }
  369. f32.i = (sign << 31) | (f_exp << 23) | (h_mant << 13);
  370. return f32.f;
  371. }
  /* The bit-pattern macros are private to this header; remove them so they
   * do not leak into code that includes it. */
  #undef CL_HALF_MAX_FINITE_MAG
  #undef CL_HALF_EXP_MASK
  374. #ifdef __cplusplus
  375. }
  376. #endif
  377. #endif /* OPENCL_CL_HALF_H */