shl_c908.h 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. /*
  2. * Copyright (C) 2016-2022 T-Head Semiconductor Co., Ltd. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* CSI-NN2 version 2.0.x */
  19. #ifndef INCLUDE_SHL_C908_H_
  20. #define INCLUDE_SHL_C908_H_
  21. #include "csi_nn.h"
  22. #include "shl_gref.h"
  23. #include "shl_ref.h"
  24. #include "shl_thead_rvv.h"
  /*********************************** initialization ***********************************/
/*
 * Initialization entry points for conv2d and depthwise conv2d on the C908
 * backend, one per data type (fp32, fp16, int8).
 *
 *   input, output - the operator's input and output tensors
 *   kernel, bias  - the convolution kernel and bias tensors
 *   params        - conv2d parameters
 *
 * Each returns an int status; the value convention is defined by the
 * implementation, which this header does not show.
 * NOTE(review): presumably these select a C908 compute routine for the given
 * shapes/params (and may pre-reorder `kernel`) -- confirm in the .c sources.
 */
int shl_c908_conv2d_init_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                              struct csinn_tensor *kernel, struct csinn_tensor *bias,
                              struct csinn_conv2d_params *params);
int shl_c908_conv2d_init_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                              struct csinn_tensor *kernel, struct csinn_tensor *bias,
                              struct csinn_conv2d_params *params);
int shl_c908_conv2d_init_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                              struct csinn_tensor *kernel, struct csinn_tensor *bias,
                              struct csinn_conv2d_params *params);
int shl_c908_depthwise_conv2d_init_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                        struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                        struct csinn_conv2d_params *params);
int shl_c908_depthwise_conv2d_init_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                        struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                        struct csinn_conv2d_params *params);
int shl_c908_depthwise_conv2d_init_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                        struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                        struct csinn_conv2d_params *params);
  /* Pooling initialization. */
/*
 * Initialization entry points for average and max 2-D pooling on the C908
 * backend, one per data type (fp32, fp16, int8).
 *
 *   input, output - the operator's input and output tensors
 *   params        - pooling parameters
 *
 * Each returns an int status whose convention is defined by the
 * implementation (not visible in this header).
 */
int shl_c908_avgpool2d_init_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_pool_params *params);
int shl_c908_avgpool2d_init_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_pool_params *params);
int shl_c908_avgpool2d_init_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_pool_params *params);
int shl_c908_maxpool2d_init_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_pool_params *params);
int shl_c908_maxpool2d_init_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_pool_params *params);
int shl_c908_maxpool2d_init_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_pool_params *params);
  /* Fully connected initialization. */
/*
 * Initialization entry point for fully connected layers on the C908 backend.
 * Unlike the conv/pool init functions it carries no data-type suffix.
 * NOTE(review): presumably it dispatches on the tensors' dtype internally --
 * confirm in the implementation.
 *
 *   input, output - the operator's input and output tensors
 *   weights, bias - the weight and bias tensors
 *   params        - fully connected parameters
 *
 * Returns an int status whose convention is defined by the implementation.
 */
int shl_c908_fullyconnected_init(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_tensor *weights, struct csinn_tensor *bias,
                                 struct csinn_fc_params *params);
  /************************************ convolution *********************************/
/*********************************** im2col + gemm ********************************/
/*
 * im2col + GEMM convolution, variant without a pack suffix.
 * NOTE(review): presumably the plain (unpacked) layout, as opposed to the
 * packn / pack1ton / packnto1 variants below -- confirm.
 *
 *   *_reorder_kernel_*   - rearrange `kernel` according to `params`; no return
 *                          value. Presumably run once ahead of the compute call.
 *   *_im2col_gemm_*      - compute entry points taking input/output/kernel/bias
 *                          and conv2d params; return an int status.
 */
void shl_c908_conv_im2col_gemm_reorder_kernel_fp32(struct csinn_tensor *kernel,
                                                   struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_fp16(struct csinn_tensor *kernel,
                                                   struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_int8(struct csinn_tensor *kernel,
                                                   struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                   struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                   struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                   struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                   struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                   struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                   struct csinn_conv2d_params *params);
  /* im2col + gemm, packn layout variant. */
/*
 * Same shape of API as the unsuffixed im2col + GEMM functions (a kernel
 * reorder step plus an int-returning compute step), for the "packn" layout.
 * NOTE(review): presumably channels packed by the vector length on both input
 * and output -- confirm against the layout code.
 */
void shl_c908_conv_im2col_gemm_reorder_kernel_packn_fp32(struct csinn_tensor *kernel,
                                                         struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_packn_fp16(struct csinn_tensor *kernel,
                                                         struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_packn_int8(struct csinn_tensor *kernel,
                                                         struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_packn_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_packn_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_packn_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
  /* im2col + gemm, pack1ton layout variant. */
/*
 * Kernel reorder + compute pair for the "pack1ton" layout.
 * NOTE(review): per the name, presumably unpacked input and packn output --
 * confirm against the layout code.
 */
void shl_c908_conv_im2col_gemm_reorder_kernel_pack1ton_fp32(struct csinn_tensor *kernel,
                                                            struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_pack1ton_fp16(struct csinn_tensor *kernel,
                                                            struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_pack1ton_int8(struct csinn_tensor *kernel,
                                                            struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_pack1ton_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                            struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                            struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_pack1ton_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                            struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                            struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_pack1ton_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                            struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                            struct csinn_conv2d_params *params);
  /* im2col + gemm, packnto1 layout variant. */
/*
 * Kernel reorder + compute pair for the "packnto1" layout.
 * NOTE(review): per the name, presumably packn input and unpacked output --
 * confirm against the layout code.
 */
void shl_c908_conv_im2col_gemm_reorder_kernel_packnto1_fp32(struct csinn_tensor *kernel,
                                                            struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_packnto1_fp16(struct csinn_tensor *kernel,
                                                            struct csinn_conv2d_params *params);
void shl_c908_conv_im2col_gemm_reorder_kernel_packnto1_int8(struct csinn_tensor *kernel,
                                                            struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_packnto1_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                            struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                            struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_packnto1_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                            struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                            struct csinn_conv2d_params *params);
int shl_c908_conv_im2col_gemm_packnto1_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                            struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                            struct csinn_conv2d_params *params);
  /******************************** conv2d1x1s1 + gemm ******************************/
/*
 * GEMM-based convolution entry points for the conv1x1s1 case, variant without
 * a pack suffix. NOTE(review): presumably 1x1 kernels at stride 1, which need
 * no im2col step -- confirm where these are selected.
 *
 *   *_reorder_kernel_*  - rearrange `kernel` according to `params`; no return value
 *   *_gemm_*            - compute entry points; return an int status
 */
void shl_c908_conv1x1s1_gemm_reorder_kernel_fp32(struct csinn_tensor *kernel,
                                                 struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_fp16(struct csinn_tensor *kernel,
                                                 struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_int8(struct csinn_tensor *kernel,
                                                 struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                 struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                 struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                 struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                 struct csinn_conv2d_params *params);
  /* conv1x1s1 + gemm, packn layout variant. */
/*
 * Kernel reorder + compute pair for conv1x1s1 in the "packn" layout
 * (same parameter/return conventions as the unsuffixed conv1x1s1 functions).
 */
void shl_c908_conv1x1s1_gemm_reorder_kernel_packn_fp32(struct csinn_tensor *kernel,
                                                       struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_packn_fp16(struct csinn_tensor *kernel,
                                                       struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_packn_int8(struct csinn_tensor *kernel,
                                                       struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_packn_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                       struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                       struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_packn_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                       struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                       struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_packn_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                       struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                       struct csinn_conv2d_params *params);
  /* conv1x1s1 + gemm, pack1ton layout variant. */
/*
 * Kernel reorder + compute pair for conv1x1s1 in the "pack1ton" layout.
 * NOTE(review): per the name, presumably unpacked input and packn output --
 * confirm.
 */
void shl_c908_conv1x1s1_gemm_reorder_kernel_pack1ton_fp32(struct csinn_tensor *kernel,
                                                          struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_pack1ton_fp16(struct csinn_tensor *kernel,
                                                          struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_pack1ton_int8(struct csinn_tensor *kernel,
                                                          struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_pack1ton_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                          struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                          struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_pack1ton_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                          struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                          struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_pack1ton_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                          struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                          struct csinn_conv2d_params *params);
  /* conv1x1s1 + gemm, packnto1 layout variant. */
/*
 * Kernel reorder + compute pair for conv1x1s1 in the "packnto1" layout.
 * NOTE(review): per the name, presumably packn input and unpacked output --
 * confirm.
 */
void shl_c908_conv1x1s1_gemm_reorder_kernel_packnto1_fp32(struct csinn_tensor *kernel,
                                                          struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_packnto1_fp16(struct csinn_tensor *kernel,
                                                          struct csinn_conv2d_params *params);
void shl_c908_conv1x1s1_gemm_reorder_kernel_packnto1_int8(struct csinn_tensor *kernel,
                                                          struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_packnto1_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                          struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                          struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_packnto1_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                          struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                          struct csinn_conv2d_params *params);
int shl_c908_conv1x1s1_gemm_packnto1_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                          struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                          struct csinn_conv2d_params *params);
  /*********************************** winograd ***********************************/
/*
 * Winograd kernel-transform helpers (pack8 / pack16 variants).
 * Each takes a source kernel tensor (`src_kernel`) and a destination tensor
 * (`dst_kernel`) and returns nothing.
 * NOTE(review): per the naming, b6f3s1 / b4f3s1 presumably mean 6x6 / 4x4
 * input blocks for a 3x3 filter at stride 1 -- confirm in the implementation.
 */
void shl_c908_wg_b6f3s1_trans_kernel_pack8_fp32(struct csinn_tensor *src_kernel,
                                                struct csinn_tensor *dst_kernel);
void shl_c908_wg_b6f3s1_trans_kernel_pack8_fp16(struct csinn_tensor *src_kernel,
                                                struct csinn_tensor *dst_kernel);
void shl_c908_wg_b6f3s1_trans_kernel_pack16_fp16(struct csinn_tensor *src_kernel,
                                                 struct csinn_tensor *dst_kernel);
void shl_c908_wg_b4f3s1_trans_kernel_pack8_fp32(struct csinn_tensor *src_kernel,
                                                struct csinn_tensor *dst_kernel);
void shl_c908_wg_b4f3s1_trans_kernel_pack8_fp16(struct csinn_tensor *src_kernel,
                                                struct csinn_tensor *dst_kernel);
void shl_c908_wg_b4f3s1_trans_kernel_pack16_fp16(struct csinn_tensor *src_kernel,
                                                 struct csinn_tensor *dst_kernel);
void shl_c908_wg_b4f3s1_trans_kernel_pack8_int8(struct csinn_tensor *src_kernel,
                                                struct csinn_tensor *dst_kernel);
  /* Winograd compute entry points (pack8 / pack16). */
/*
 * Winograd convolution compute functions matching the kernel transforms
 * above by name. They take input/output/kernel/bias tensors plus conv2d
 * params and return an int status.
 * NOTE(review): presumably `kernel` must already have been produced by the
 * matching *_trans_kernel_* helper -- confirm in the init code.
 */
int shl_c908_wg_b6f3s1_pack8_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                  struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                  struct csinn_conv2d_params *params);
int shl_c908_wg_b6f3s1_pack8_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                  struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                  struct csinn_conv2d_params *params);
int shl_c908_wg_b6f3s1_pack16_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                   struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                   struct csinn_conv2d_params *params);
int shl_c908_wg_b4f3s1_pack8_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                  struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                  struct csinn_conv2d_params *params);
int shl_c908_wg_b4f3s1_pack8_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                  struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                  struct csinn_conv2d_params *params);
int shl_c908_wg_b4f3s1_pack16_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                   struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                   struct csinn_conv2d_params *params);
int shl_c908_wg_b4f3s1_pack8_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                  struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                  struct csinn_conv2d_params *params);
  /* Winograd b6f3s1, ncxhwx layout (packn). */
/*
 * Kernel transform (src_kernel -> dst_kernel, no return value) and compute
 * entry point (int status) for the ncxhwx / packn Winograd b6f3s1 path.
 * fp32 and fp16 only.
 */
void shl_c908_ncxhwx_wg_b6f3s1_trans_kernel_packn_fp32(struct csinn_tensor *src_kernel,
                                                       struct csinn_tensor *dst_kernel);
void shl_c908_ncxhwx_wg_b6f3s1_trans_kernel_packn_fp16(struct csinn_tensor *src_kernel,
                                                       struct csinn_tensor *dst_kernel);
int shl_c908_ncxhwx_wg_b6f3s1_packn_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
int shl_c908_ncxhwx_wg_b6f3s1_packn_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
  /* Winograd b4f3s1, ncxhwx layout (packn). */
/*
 * Kernel transform (src_kernel -> dst_kernel, no return value) and compute
 * entry point (int status) for the ncxhwx / packn Winograd b4f3s1 path.
 * Unlike b6f3s1, this path also has an int8 variant.
 */
void shl_c908_ncxhwx_wg_b4f3s1_trans_kernel_packn_fp32(struct csinn_tensor *src_kernel,
                                                       struct csinn_tensor *dst_kernel);
void shl_c908_ncxhwx_wg_b4f3s1_trans_kernel_packn_fp16(struct csinn_tensor *src_kernel,
                                                       struct csinn_tensor *dst_kernel);
void shl_c908_ncxhwx_wg_b4f3s1_trans_kernel_packn_int8(struct csinn_tensor *src_kernel,
                                                       struct csinn_tensor *dst_kernel);
int shl_c908_ncxhwx_wg_b4f3s1_packn_fp32(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
int shl_c908_ncxhwx_wg_b4f3s1_packn_fp16(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
int shl_c908_ncxhwx_wg_b4f3s1_packn_int8(struct csinn_tensor *input, struct csinn_tensor *output,
                                         struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                         struct csinn_conv2d_params *params);
  243. /*********************************** gemm ncxhwx kernel ***********************************/
  244. void shl_c908_ncxhwx_gemm_12xpack2n_fp32(float *dst, const float *sa, const float *sb,
  245. const float *bias, int m, int k, int n, bool fuse_relu);
  246. void shl_c908_ncxhwx_gemm_12xpack2n_fp16(__fp16 *dst, const __fp16 *sa, const __fp16 *sb,
  247. const __fp16 *bias, int m, int k, int n, bool fuse_relu);
  248. void shl_c908_ncxhwx_gemm_12xpackn_int8(int8_t *dst, const int8_t *sa, const int8_t *sb,
  249. const int32_t *bias, int m, int k, int n, int32_t out_zp,
  250. int32_t *mult, int32_t *shift);
  251. void shl_c908_ncxhwx_gemm_12xpackn_int16(int32_t *dst, const int16_t *sa, const int16_t *sb, int m,
  252. int k, int n);
  253. /*********************************** gemm kernel ***********************************/
  254. void shl_c908_reorder_kernel_n8_fp32(float *src, float *dst, int m, int k, int ldc);
  255. void shl_c908_reorder_input_z12_fp32(float *src, float *dst, int k, int n, int ldc);
  256. void shl_c908_gemm_8x12_fp32(float *dst, const float *sa, const float *sb, float *bias, int m,
  257. int k, int n, int ldc);
  258. void shl_c908_reorder_kernel_n8_fp16(__fp16 *src, __fp16 *dst, int m, int k, int ldc);
  259. void shl_c908_reorder_input_z24_fp16(__fp16 *src, __fp16 *dst, int k, int n, int ldc);
  260. void shl_c908_gemm_8x24_fp16(__fp16 *dst, const __fp16 *sa, const __fp16 *sb, __fp16 *bias, int m,
  261. int k, int n, int ldc);
  262. void shl_c908_reorder_kernel_n8_int8(int8_t *src, int8_t *dst, int m, int k, int ldc);
  263. void shl_c908_reorder_input_z8_int8(int8_t *src, int8_t *dst, int k, int n, int ldc);
  264. void shl_c908_gemm_8x8_int8(int8_t *dst, const int8_t *sa, const int8_t *sb, int32_t *bias, int m,
  265. int k, int n, int ldc, int32_t out_zp, int32_t *mult, int32_t *shift);
  266. void shl_c908_reorder_input_z12_int8(int8_t *src, int8_t *dst, int k, int n, int ldc);
  267. /*********************************** VLEN = 256 ***********************************/
  268. /*********************************** VLEN = 256 ***********************************/
  269. /*********************************** VLEN = 256 ***********************************/
  270. void shl_c908_reorder_input_z16_fp32_v256(float *src, float *dst, int k, int n, int ldc);
  271. void shl_c908_gemm_8x16_fp32_v256(float *dst, const float *sa, const float *sb, float *bias, int m,
  272. int k, int n, int ldc);
  273. void shl_c908_reorder_input_z32_fp16_v256(__fp16 *src, __fp16 *dst, int k, int n, int ldc);
  274. void shl_c908_gemm_8x32_fp16_v256(__fp16 *dst, const __fp16 *sa, const __fp16 *sb, __fp16 *bias,
  275. int m, int k, int n, int ldc);
  276. void shl_c908_reorder_input_z16_int8_v256(int8_t *src, int8_t *dst, int k, int n, int ldc);
  277. void shl_c908_gemm_8x16_int8_v256(int8_t *dst, const int8_t *sa, const int8_t *sb, int32_t *bias,
  278. int m, int k, int n, int ldc, int32_t out_zp, int32_t *mult,
  279. int32_t *shift);
  /* Alternative register blockings, declared only with SHL_UNUSED_REGISTER_BLK. */
/*
 * These declarations are compiled only when SHL_UNUSED_REGISTER_BLK is
 * defined. The group mixes non-_v256 kernels (z8 / 8x8 fp32, z16 / 8x16 fp16)
 * with _v256 kernels (z24 / 8x24 fp32, z48 / 8x48 fp16), even though it sits
 * after the VLEN = 256 banner. Parameter conventions match the default
 * reorder helpers and GEMM micro-kernels.
 */
#ifdef SHL_UNUSED_REGISTER_BLK
void shl_c908_reorder_input_z8_fp32(float *src, float *dst, int k, int n, int ldc);
void shl_c908_gemm_8x8_fp32(float *dst, const float *sa, const float *sb, float *bias, int m, int k,
                            int n, int ldc);
void shl_c908_reorder_input_z16_fp16(__fp16 *src, __fp16 *dst, int k, int n, int ldc);
void shl_c908_gemm_8x16_fp16(__fp16 *dst, const __fp16 *sa, const __fp16 *sb, __fp16 *bias, int m,
                             int k, int n, int ldc);
void shl_c908_reorder_input_z24_fp32_v256(float *src, float *dst, int k, int n, int ldc);
void shl_c908_gemm_8x24_fp32_v256(float *dst, const float *sa, const float *sb, float *bias, int m,
                                  int k, int n, int ldc);
void shl_c908_reorder_input_z48_fp16_v256(__fp16 *src, __fp16 *dst, int k, int n, int ldc);
void shl_c908_gemm_8x48_fp16_v256(__fp16 *dst, const __fp16 *sa, const __fp16 *sb, __fp16 *bias,
                                  int m, int k, int n, int ldc);
#endif /* SHL_UNUSED_REGISTER_BLK */
  /* int4 initialization, declared only with SHL_USE_DOT_INT4. */
#ifdef SHL_USE_DOT_INT4
/*
 * int4 initialization entry points for conv2d and depthwise conv2d. They take
 * the same parameters and return the same int status as the fp32/fp16/int8
 * init functions, and are declared only when SHL_USE_DOT_INT4 is defined.
 */
int shl_c908_conv2d_init_int4(struct csinn_tensor *input, struct csinn_tensor *output,
                              struct csinn_tensor *kernel, struct csinn_tensor *bias,
                              struct csinn_conv2d_params *params);
int shl_c908_depthwise_conv2d_init_int4(struct csinn_tensor *input, struct csinn_tensor *output,
                                        struct csinn_tensor *kernel, struct csinn_tensor *bias,
                                        struct csinn_conv2d_params *params);
#endif /* SHL_USE_DOT_INT4 */
  302. #endif // INCLUDE_SHL_C908_H_