// detect.cpp (6.7 KB)
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <string>
#include <vector>

#include "detect.h"
  7. int num_class = 21;
  8. float nms_threshold = 0.45f;
  9. int nms_top_k = 100;
  10. int keep_top_k = 100;
  11. //float confidence_threshold = 0.51f;
  12. float confidence_threshold = 0.45f;
  13. //float confidence_threshold = 0.48f;
  14. static inline float intersection_area(const BBoxRect& a, const BBoxRect& b)
  15. {
  16. if (a.xmin > b.xmax || a.xmax < b.xmin || a.ymin > b.ymax || a.ymax < b.ymin)
  17. {
  18. // no intersection
  19. return 0.f;
  20. }
  21. float inter_width = std::min(a.xmax, b.xmax) - std::max(a.xmin, b.xmin);
  22. float inter_height = std::min(a.ymax, b.ymax) - std::max(a.ymin, b.ymin);
  23. return inter_width * inter_height;
  24. }
  25. template <typename T>
  26. static void qsort_descent_inplace(std::vector<T>& datas, std::vector<float>& scores, int left, int right)
  27. {
  28. int i = left;
  29. int j = right;
  30. //must be instanced ... so this function must include vector
  31. float p = scores[(left + right) / 2];
  32. while (i <= j)
  33. {
  34. while (scores[i] > p)
  35. i++;
  36. while (scores[j] < p)
  37. j--;
  38. if (i <= j)
  39. {
  40. // swap
  41. std::swap(datas[i], datas[j]);
  42. std::swap(scores[i], scores[j]);
  43. i++;
  44. j--;
  45. }
  46. }
  47. if (left < j)
  48. qsort_descent_inplace(datas, scores, left, j);
  49. if (i < right)
  50. qsort_descent_inplace(datas, scores, i, right);
  51. }
  52. template <typename T>
  53. static void qsort_descent_inplace(std::vector<T>& datas, std::vector<float>& scores)
  54. {
  55. if (datas.empty() || scores.empty())
  56. return;
  57. qsort_descent_inplace(datas, scores, 0, scores.size() - 1);
  58. }
  59. static void nms_sorted_bboxes(const std::vector<BBoxRect>& bboxes, std::vector<int>& picked, float nms_threshold)
  60. {
  61. picked.clear();
  62. const int n = bboxes.size();
  63. std::vector<float> areas(n);
  64. for (int i = 0; i < n; i++)
  65. {
  66. const BBoxRect& r = bboxes[i];
  67. float width = r.xmax - r.xmin;
  68. float height = r.ymax - r.ymin;
  69. areas[i] = width * height;
  70. }
  71. for (int i = 0; i < n; i++)
  72. {
  73. const BBoxRect& a = bboxes[i];
  74. int keep = 1;
  75. for (int j = 0; j < (int)picked.size(); j++)
  76. {
  77. const BBoxRect& b = bboxes[picked[j]];
  78. float interarea = intersection_area(a, b);
  79. float unionarea = areas[i] + areas[picked[j]] - interarea;
  80. if (interarea / unionarea > nms_threshold)
  81. keep = 0;
  82. }
  83. if (keep)
  84. picked.push_back(i);
  85. }
  86. }
  87. int ssdforward(float *location,float * confidence,float * priorbox,BBox *bboxes,BBoxOut *out)
  88. {
  89. const float* location_ptr = location;
  90. const float* priorbox_ptr = priorbox;
  91. const float* variance_ptr = priorbox + num_prior*4;
  92. for (int i = 0; i < num_prior; i++)
  93. {
  94. const float* loc = location_ptr + i * 4;
  95. const float* pb = priorbox_ptr + i * 4;
  96. const float* var = variance_ptr;// + i * 4;
  97. float* bbox = (float*)&bboxes[i];// bboxes.row(i);
  98. // CENTER_SIZE
  99. float pb_w = pb[2] - pb[0];
  100. float pb_h = pb[3] - pb[1];
  101. float pb_cx = (pb[0] + pb[2]) * 0.5f;
  102. float pb_cy = (pb[1] + pb[3]) * 0.5f;
  103. float bbox_cx = var[0] * loc[0] * pb_w + pb_cx;
  104. float bbox_cy = var[1] * loc[1] * pb_h + pb_cy;
  105. float bbox_w = exp(var[2] * loc[2]) * pb_w;
  106. float bbox_h = exp(var[3] * loc[3]) * pb_h;
  107. bbox[0] = bbox_cx - bbox_w * 0.5f;
  108. bbox[1] = bbox_cy - bbox_h * 0.5f;
  109. bbox[2] = bbox_cx + bbox_w * 0.5f;
  110. bbox[3] = bbox_cy + bbox_h * 0.5f;
  111. }
  112. // sort and nms for each class
  113. std::vector< std::vector<BBoxRect> > all_class_bbox_rects;
  114. std::vector< std::vector<float> > all_class_bbox_scores;
  115. all_class_bbox_rects.resize(num_class);
  116. all_class_bbox_scores.resize(num_class);
  117. // start from 1 to ignore background class
  118. for (int i = 1; i < num_class; i++)
  119. {
  120. // filter by confidence_threshold
  121. std::vector<BBoxRect> class_bbox_rects;
  122. std::vector<float> class_bbox_scores;
  123. for (int j = 0; j < num_prior; j++)
  124. {
  125. float score = confidence[j * num_class + i];
  126. if (score > confidence_threshold)
  127. {
  128. const float* bbox = (float*)&bboxes[j];
  129. BBoxRect c = { bbox[0], bbox[1], bbox[2], bbox[3], i };
  130. class_bbox_rects.push_back(c);
  131. class_bbox_scores.push_back(score);
  132. }
  133. }
  134. // sort inplace
  135. qsort_descent_inplace(class_bbox_rects, class_bbox_scores);
  136. // keep nms_top_k
  137. if (nms_top_k < (int)class_bbox_rects.size())
  138. {
  139. class_bbox_rects.resize(nms_top_k);
  140. class_bbox_scores.resize(nms_top_k);
  141. }
  142. // apply nms
  143. std::vector<int> picked;
  144. nms_sorted_bboxes(class_bbox_rects, picked, nms_threshold);
  145. // select
  146. for (int j = 0; j < (int)picked.size(); j++)
  147. {
  148. int z = picked[j];
  149. all_class_bbox_rects[i].push_back(class_bbox_rects[z]);
  150. all_class_bbox_scores[i].push_back(class_bbox_scores[z]);
  151. }
  152. }
  153. // gather all class
  154. std::vector<BBoxRect> bbox_rects;
  155. std::vector<float> bbox_scores;
  156. for (int i = 1; i < num_class; i++)
  157. {
  158. const std::vector<BBoxRect>& class_bbox_rects = all_class_bbox_rects[i];
  159. const std::vector<float>& class_bbox_scores = all_class_bbox_scores[i];
  160. bbox_rects.insert(bbox_rects.end(), class_bbox_rects.begin(), class_bbox_rects.end());
  161. bbox_scores.insert(bbox_scores.end(), class_bbox_scores.begin(), class_bbox_scores.end());
  162. }
  163. // global sort inplace
  164. qsort_descent_inplace(bbox_rects, bbox_scores);
  165. // keep_top_k
  166. if (keep_top_k < (int)bbox_rects.size())
  167. {
  168. bbox_rects.resize(keep_top_k);
  169. bbox_scores.resize(keep_top_k);
  170. }
  171. int num_detected = bbox_rects.size();
  172. if(num_detected >100) num_detected =100;
  173. for (int i = 0; i < num_detected; i++)
  174. {
  175. const BBoxRect& r = bbox_rects[i];
  176. float score = bbox_scores[i];
  177. float* outptr = (float*)&out[i];
  178. int *labelptr = (int *)outptr;
  179. labelptr[0] = r.label;
  180. outptr[1] = score;
  181. outptr[2] = r.xmin;
  182. outptr[3] = r.ymin;
  183. outptr[4] = r.xmax;
  184. outptr[5] = r.ymax;
  185. }
  186. return num_detected;
  187. }
  188. int readbintomem(float *dst,char *path)
  189. {
  190. FILE *pFile = fopen (path, "rb" );
  191. if (pFile==NULL)
  192. {
  193. fputs ("File error",stderr);
  194. exit (1);
  195. }
  196. fseek (pFile , 0 , SEEK_END);
  197. int fsize = ftell(pFile);
  198. rewind (pFile);
  199. //buffer = (char*) malloc (sizeof(char)*lSize);
  200. int result = fread (dst,1,fsize,pFile);
  201. fclose(pFile);
  202. return result;
  203. }