detect.cpp 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>
#include <algorithm>
#include "detect.h"
  7. int num_class = 21;
  8. float nms_threshold = 0.45f;
  9. int nms_top_k = 100;
  10. int keep_top_k = 100;
  11. float confidence_threshold = 0.51f;
  12. static inline float intersection_area(const BBoxRect& a, const BBoxRect& b)
  13. {
  14. if (a.xmin > b.xmax || a.xmax < b.xmin || a.ymin > b.ymax || a.ymax < b.ymin)
  15. {
  16. // no intersection
  17. return 0.f;
  18. }
  19. float inter_width = std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin);
  20. float inter_height = std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin);
  21. return inter_width * inter_height;
  22. }
  23. template <typename T>
  24. static void qsort_descent_inplace(std::vector<T>& datas, std::vector<float>& scores, int left, int right)
  25. {
  26. int i = left;
  27. int j = right;
  28. //must be instanced ... so this function must include vector
  29. float p = scores[(left + right) / 2];
  30. while (i <= j)
  31. {
  32. while (scores[i] > p)
  33. i++;
  34. while (scores[j] < p)
  35. j--;
  36. if (i <= j)
  37. {
  38. // swap
  39. std::swap(datas[i], datas[j]);
  40. std::swap(scores[i], scores[j]);
  41. i++;
  42. j--;
  43. }
  44. }
  45. if (left < j)
  46. qsort_descent_inplace(datas, scores, left, j);
  47. if (i < right)
  48. qsort_descent_inplace(datas, scores, i, right);
  49. }
  50. template <typename T>
  51. static void qsort_descent_inplace(std::vector<T>& datas, std::vector<float>& scores)
  52. {
  53. if (datas.empty() || scores.empty())
  54. return;
  55. qsort_descent_inplace(datas, scores, 0, scores.size() - 1);
  56. }
  57. static void nms_sorted_bboxes(const std::vector<BBoxRect>& bboxes, std::vector<int>& picked, float nms_threshold)
  58. {
  59. picked.clear();
  60. const int n = bboxes.size();
  61. std::vector<float> areas(n);
  62. for (int i = 0; i < n; i++)
  63. {
  64. const BBoxRect& r = bboxes[i];
  65. float width = r.xmax - r.xmin;
  66. float height = r.ymax - r.ymin;
  67. areas[i] = width * height;
  68. }
  69. for (int i = 0; i < n; i++)
  70. {
  71. const BBoxRect& a = bboxes[i];
  72. int keep = 1;
  73. for (int j = 0; j < (int)picked.size(); j++)
  74. {
  75. const BBoxRect& b = bboxes[picked[j]];
  76. float interarea = intersection_area(a, b);
  77. float unionarea = areas[i] + areas[picked[j]] - interarea;
  78. if (interarea / unionarea > nms_threshold)
  79. keep = 0;
  80. }
  81. if (keep)
  82. picked.push_back(i);
  83. }
  84. }
  85. int ssdforward(float *location,float * confidence,float * priorbox,BBox *bboxes,BBoxOut *out)
  86. {
  87. const float* location_ptr = location;
  88. const float* priorbox_ptr = priorbox;
  89. const float* variance_ptr = priorbox + num_prior*4;
  90. for (int i = 0; i < num_prior; i++)
  91. {
  92. const float* loc = location_ptr + i * 4;
  93. const float* pb = priorbox_ptr + i * 4;
  94. const float* var = variance_ptr;// + i * 4;
  95. float* bbox = (float*)&bboxes[i];// bboxes.row(i);
  96. // CENTER_SIZE
  97. float pb_w = pb[2] - pb[0];
  98. float pb_h = pb[3] - pb[1];
  99. float pb_cx = (pb[0] + pb[2]) * 0.5f;
  100. float pb_cy = (pb[1] + pb[3]) * 0.5f;
  101. float bbox_cx = var[0] * loc[0] * pb_w + pb_cx;
  102. float bbox_cy = var[1] * loc[1] * pb_h + pb_cy;
  103. float bbox_w = exp(var[2] * loc[2]) * pb_w;
  104. float bbox_h = exp(var[3] * loc[3]) * pb_h;
  105. bbox[0] = bbox_cx - bbox_w * 0.5f;
  106. bbox[1] = bbox_cy - bbox_h * 0.5f;
  107. bbox[2] = bbox_cx + bbox_w * 0.5f;
  108. bbox[3] = bbox_cy + bbox_h * 0.5f;
  109. }
  110. // sort and nms for each class
  111. std::vector< std::vector<BBoxRect> > all_class_bbox_rects;
  112. std::vector< std::vector<float> > all_class_bbox_scores;
  113. all_class_bbox_rects.resize(num_class);
  114. all_class_bbox_scores.resize(num_class);
  115. // start from 1 to ignore background class
  116. for (int i = 1; i < num_class; i++)
  117. {
  118. // filter by confidence_threshold
  119. std::vector<BBoxRect> class_bbox_rects;
  120. std::vector<float> class_bbox_scores;
  121. for (int j = 0; j < num_prior; j++)
  122. {
  123. float score = confidence[j * num_class + i];
  124. if (score > confidence_threshold)
  125. {
  126. const float* bbox = (float*)&bboxes[j];
  127. BBoxRect c = { bbox[0], bbox[1], bbox[2], bbox[3], i };
  128. class_bbox_rects.push_back(c);
  129. class_bbox_scores.push_back(score);
  130. }
  131. }
  132. // sort inplace
  133. qsort_descent_inplace(class_bbox_rects, class_bbox_scores);
  134. // keep nms_top_k
  135. if (nms_top_k < (int)class_bbox_rects.size())
  136. {
  137. class_bbox_rects.resize(nms_top_k);
  138. class_bbox_scores.resize(nms_top_k);
  139. }
  140. // apply nms
  141. std::vector<int> picked;
  142. nms_sorted_bboxes(class_bbox_rects, picked, nms_threshold);
  143. // select
  144. for (int j = 0; j < (int)picked.size(); j++)
  145. {
  146. int z = picked[j];
  147. all_class_bbox_rects[i].push_back(class_bbox_rects[z]);
  148. all_class_bbox_scores[i].push_back(class_bbox_scores[z]);
  149. }
  150. }
  151. // gather all class
  152. std::vector<BBoxRect> bbox_rects;
  153. std::vector<float> bbox_scores;
  154. for (int i = 1; i < num_class; i++)
  155. {
  156. const std::vector<BBoxRect>& class_bbox_rects = all_class_bbox_rects[i];
  157. const std::vector<float>& class_bbox_scores = all_class_bbox_scores[i];
  158. bbox_rects.insert(bbox_rects.end(), class_bbox_rects.begin(), class_bbox_rects.end());
  159. bbox_scores.insert(bbox_scores.end(), class_bbox_scores.begin(), class_bbox_scores.end());
  160. }
  161. // global sort inplace
  162. qsort_descent_inplace(bbox_rects, bbox_scores);
  163. // keep_top_k
  164. if (keep_top_k < (int)bbox_rects.size())
  165. {
  166. bbox_rects.resize(keep_top_k);
  167. bbox_scores.resize(keep_top_k);
  168. }
  169. int num_detected = bbox_rects.size();
  170. if(num_detected >100) num_detected =100;
  171. for (int i = 0; i < num_detected; i++)
  172. {
  173. const BBoxRect& r = bbox_rects[i];
  174. float score = bbox_scores[i];
  175. float* outptr = (float*)&out[i];
  176. int *labelptr = (int *)outptr;
  177. labelptr[0] = r.label;
  178. outptr[1] = score;
  179. outptr[2] = r.xmin;
  180. outptr[3] = r.ymin;
  181. outptr[4] = r.xmax;
  182. outptr[5] = r.ymax;
  183. }
  184. return num_detected;
  185. }
  186. int readbintomem(float *dst,char *path)
  187. {
  188. FILE *pFile = fopen (path, "rb" );
  189. if (pFile==NULL)
  190. {
  191. fputs ("File error",stderr);
  192. exit (1);
  193. }
  194. fseek (pFile , 0 , SEEK_END);
  195. int fsize = ftell(pFile);
  196. rewind (pFile);
  197. //buffer = (char*) malloc (sizeof(char)*lSize);
  198. int result = fread (dst,1,fsize,pFile);
  199. fclose(pFile);
  200. return result;
  201. }