wsola_internals.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. // Copyright 2013 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/filters/wsola_internals.h"
  5. #include <algorithm>
  6. #include <cmath>
  7. #include <cstring>
  8. #include <limits>
  9. #include <memory>
  10. #include "base/check_op.h"
  11. #include "base/numerics/math_constants.h"
  12. #include "build/build_config.h"
  13. #include "media/base/audio_bus.h"
  14. #if defined(ARCH_CPU_X86_FAMILY)
  15. #define USE_SIMD 1
  16. #include <xmmintrin.h>
  17. #elif defined(ARCH_CPU_ARM_FAMILY) && defined(USE_NEON)
  18. #define USE_SIMD 1
  19. #include <arm_neon.h>
  20. #endif
  21. namespace media {
  22. namespace internal {
  23. bool InInterval(int n, Interval q) {
  24. return n >= q.first && n <= q.second;
  25. }
  // Sums, over the first |channels| channels, the normalized correlation
  //   dot_prod_a_b[ch] / sqrt(energy_a[ch] * energy_b[ch] + kEpsilon).
  // kEpsilon keeps the denominator away from zero when either energy is 0.
  // Channels are accumulated in index order, starting from 0.0f.
  float MultiChannelSimilarityMeasure(const float* dot_prod_a_b,
                                      const float* energy_a,
                                      const float* energy_b,
                                      int channels) {
    constexpr float kEpsilon = 1e-12f;
    float total = 0.0f;
    for (int ch = 0; ch < channels; ++ch) {
      const float denominator =
          std::sqrt(energy_a[ch] * energy_b[ch] + kEpsilon);
      total += dot_prod_a_b[ch] / denominator;
    }
    return total;
  }
  // Computes, for each channel ch, the dot product of |num_frames| samples of
  // |a| starting at |frame_offset_a| with |num_frames| samples of |b| starting
  // at |frame_offset_b|, and stores the result in dot_product[ch].
  // |dot_product| must hold a->channels() floats; its previous contents are
  // overwritten. |a| and |b| must have the same channel count and both frame
  // ranges must be in bounds (enforced only by DCHECKs).
  void MultiChannelDotProduct(const AudioBus* a,
                              int frame_offset_a,
                              const AudioBus* b,
                              int frame_offset_b,
                              int num_frames,
                              float* dot_product) {
    DCHECK_EQ(a->channels(), b->channels());
    DCHECK_GE(frame_offset_a, 0);
    DCHECK_GE(frame_offset_b, 0);
    DCHECK_LE(frame_offset_a + num_frames, a->frames());
    DCHECK_LE(frame_offset_b + num_frames, b->frames());

    // SIMD optimized variants can provide a massive speedup to this operation.
#if defined(USE_SIMD)
    // The vector loops below cover the largest multiple of 4 frames
    // (|last_index|); the last |rem| frames are left for the scalar loop.
    const int rem = num_frames % 4;
    const int last_index = num_frames - rem;
    const int channels = a->channels();
    for (int ch = 0; ch < channels; ++ch) {
      const float* a_src = a->channel(ch) + frame_offset_a;
      const float* b_src = b->channel(ch) + frame_offset_b;

#if defined(ARCH_CPU_X86_FAMILY)
      // First sum all components.
      __m128 m_sum = _mm_setzero_ps();
      for (int s = 0; s < last_index; s += 4) {
        m_sum = _mm_add_ps(
            m_sum, _mm_mul_ps(_mm_loadu_ps(a_src + s), _mm_loadu_ps(b_src + s)));
      }

      // Reduce to a single float for this channel. Sadly, SSE1,2 doesn't have a
      // horizontal sum function, so we have to condense manually.
      // _mm_movehl_ps brings lanes 2,3 down so they are added to lanes 0,1;
      // the shuffle then brings lane 1 down so it is added to lane 0, and the
      // lowest lane is stored as this channel's partial sum.
      m_sum = _mm_add_ps(_mm_movehl_ps(m_sum, m_sum), m_sum);
      _mm_store_ss(dot_product + ch,
                   _mm_add_ss(m_sum, _mm_shuffle_ps(m_sum, m_sum, 1)));
#elif defined(ARCH_CPU_ARM_FAMILY)
      // First sum all components.
      float32x4_t m_sum = vmovq_n_f32(0);
      for (int s = 0; s < last_index; s += 4)
        m_sum = vmlaq_f32(m_sum, vld1q_f32(a_src + s), vld1q_f32(b_src + s));

      // Reduce to a single float for this channel.
      float32x2_t m_half = vadd_f32(vget_high_f32(m_sum), vget_low_f32(m_sum));
      dot_product[ch] = vget_lane_f32(vpadd_f32(m_half, m_half), 0);
#endif
    }

    // The vector loop consumed every frame; nothing remains for the tail.
    if (!rem)
      return;

    // Point the scalar loop below at the tail frames only. dot_product[]
    // already holds the vector partial sums, so the loop adds on top of them.
    num_frames = rem;
    frame_offset_a += last_index;
    frame_offset_b += last_index;
#else
    // No SIMD available: start from zero and let the scalar loop do all the
    // work.
    memset(dot_product, 0, sizeof(*dot_product) * a->channels());
#endif  // defined(USE_SIMD)

    // C version is required to handle remainder of frames (% 4 != 0)
    for (int k = 0; k < a->channels(); ++k) {
      const float* ch_a = a->channel(k) + frame_offset_a;
      const float* ch_b = b->channel(k) + frame_offset_b;
      for (int n = 0; n < num_frames; ++n)
        dot_product[k] += *ch_a++ * *ch_b++;
    }
  }
  95. void MultiChannelMovingBlockEnergies(const AudioBus* input,
  96. int frames_per_block,
  97. float* energy) {
  98. int num_blocks = input->frames() - (frames_per_block - 1);
  99. int channels = input->channels();
  100. for (int k = 0; k < input->channels(); ++k) {
  101. const float* input_channel = input->channel(k);
  102. energy[k] = 0;
  103. // First block of channel |k|.
  104. for (int m = 0; m < frames_per_block; ++m) {
  105. energy[k] += input_channel[m] * input_channel[m];
  106. }
  107. const float* slide_out = input_channel;
  108. const float* slide_in = input_channel + frames_per_block;
  109. for (int n = 1; n < num_blocks; ++n, ++slide_in, ++slide_out) {
  110. energy[k + n * channels] = energy[k + (n - 1) * channels] - *slide_out *
  111. *slide_out + *slide_in * *slide_in;
  112. }
  113. }
  114. }
  // Fits the parabola f(x) = a * x^2 + b * x + c through the three points
  // (-1, y_values[0]), (0, y_values[1]) and (1, y_values[2]).
  // Writes the x position of its vertex to |*extremum| and f at that position
  // to |*extremum_value|. Callers pass samples around a local maximum
  // (y_values[0] <= y_values[1] >= y_values[2]), so the vertex is a maximum.
  // If the points are colinear (a == 0 in floating point), there is no unique
  // vertex; position 0 and the middle sample are reported instead.
  void QuadraticInterpolation(const float* y_values,
                              float* extremum,
                              float* extremum_value) {
    const float left = y_values[0];
    const float mid = y_values[1];
    const float right = y_values[2];

    const float a = 0.5f * (right + left) - mid;
    const float b = 0.5f * (right - left);

    if (a != 0.f) {
      const float x = -b / (2.f * a);
      *extremum = x;
      *extremum_value = a * x * x + b * x + mid;
      return;
    }

    // Degenerate (colinear) case.
    *extremum = 0;
    *extremum_value = mid;
  }
  135. int DecimatedSearch(int decimation,
  136. Interval exclude_interval,
  137. const AudioBus* target_block,
  138. const AudioBus* search_segment,
  139. const float* energy_target_block,
  140. const float* energy_candidate_blocks) {
  141. int channels = search_segment->channels();
  142. int block_size = target_block->frames();
  143. int num_candidate_blocks = search_segment->frames() - (block_size - 1);
  144. std::unique_ptr<float[]> dot_prod(new float[channels]);
  145. float similarity[3]; // Three elements for cubic interpolation.
  146. int n = 0;
  147. MultiChannelDotProduct(target_block, 0, search_segment, n, block_size,
  148. dot_prod.get());
  149. similarity[0] = MultiChannelSimilarityMeasure(
  150. dot_prod.get(), energy_target_block,
  151. &energy_candidate_blocks[n * channels], channels);
  152. // Set the starting point as optimal point.
  153. float best_similarity = similarity[0];
  154. int optimal_index = 0;
  155. n += decimation;
  156. if (n >= num_candidate_blocks) {
  157. return 0;
  158. }
  159. MultiChannelDotProduct(target_block, 0, search_segment, n, block_size,
  160. dot_prod.get());
  161. similarity[1] = MultiChannelSimilarityMeasure(
  162. dot_prod.get(), energy_target_block,
  163. &energy_candidate_blocks[n * channels], channels);
  164. n += decimation;
  165. if (n >= num_candidate_blocks) {
  166. // We cannot do any more sampling. Compare these two values and return the
  167. // optimal index.
  168. return similarity[1] > similarity[0] ? decimation : 0;
  169. }
  170. for (; n < num_candidate_blocks; n += decimation) {
  171. MultiChannelDotProduct(target_block, 0, search_segment, n, block_size,
  172. dot_prod.get());
  173. similarity[2] = MultiChannelSimilarityMeasure(
  174. dot_prod.get(), energy_target_block,
  175. &energy_candidate_blocks[n * channels], channels);
  176. if ((similarity[1] > similarity[0] && similarity[1] >= similarity[2]) ||
  177. (similarity[1] >= similarity[0] && similarity[1] > similarity[2])) {
  178. // A local maximum is found. Do a cubic interpolation for a better
  179. // estimate of candidate maximum.
  180. float normalized_candidate_index;
  181. float candidate_similarity;
  182. QuadraticInterpolation(similarity, &normalized_candidate_index,
  183. &candidate_similarity);
  184. int candidate_index = n - decimation + static_cast<int>(
  185. normalized_candidate_index * decimation + 0.5f);
  186. if (candidate_similarity > best_similarity &&
  187. !InInterval(candidate_index, exclude_interval)) {
  188. optimal_index = candidate_index;
  189. best_similarity = candidate_similarity;
  190. }
  191. } else if (n + decimation >= num_candidate_blocks &&
  192. similarity[2] > best_similarity &&
  193. !InInterval(n, exclude_interval)) {
  194. // If this is the end-point and has a better similarity-measure than
  195. // optimal, then we accept it as optimal point.
  196. optimal_index = n;
  197. best_similarity = similarity[2];
  198. }
  199. memmove(similarity, &similarity[1], 2 * sizeof(*similarity));
  200. }
  201. return optimal_index;
  202. }
  203. int FullSearch(int low_limit,
  204. int high_limit,
  205. Interval exclude_interval,
  206. const AudioBus* target_block,
  207. const AudioBus* search_block,
  208. const float* energy_target_block,
  209. const float* energy_candidate_blocks) {
  210. int channels = search_block->channels();
  211. int block_size = target_block->frames();
  212. std::unique_ptr<float[]> dot_prod(new float[channels]);
  213. float best_similarity = std::numeric_limits<float>::min();
  214. int optimal_index = 0;
  215. for (int n = low_limit; n <= high_limit; ++n) {
  216. if (InInterval(n, exclude_interval)) {
  217. continue;
  218. }
  219. MultiChannelDotProduct(target_block, 0, search_block, n, block_size,
  220. dot_prod.get());
  221. float similarity = MultiChannelSimilarityMeasure(
  222. dot_prod.get(), energy_target_block,
  223. &energy_candidate_blocks[n * channels], channels);
  224. if (similarity > best_similarity) {
  225. best_similarity = similarity;
  226. optimal_index = n;
  227. }
  228. }
  229. return optimal_index;
  230. }
  231. int OptimalIndex(const AudioBus* search_block,
  232. const AudioBus* target_block,
  233. Interval exclude_interval) {
  234. int channels = search_block->channels();
  235. DCHECK_EQ(channels, target_block->channels());
  236. int target_size = target_block->frames();
  237. int num_candidate_blocks = search_block->frames() - (target_size - 1);
  238. // This is a compromise between complexity reduction and search accuracy. I
  239. // don't have a proof that down sample of order 5 is optimal. One can compute
  240. // a decimation factor that minimizes complexity given the size of
  241. // |search_block| and |target_block|. However, my experiments show the rate of
  242. // missing the optimal index is significant. This value is chosen
  243. // heuristically based on experiments.
  244. const int kSearchDecimation = 5;
  245. std::unique_ptr<float[]> energy_target_block(new float[channels]);
  246. std::unique_ptr<float[]> energy_candidate_blocks(
  247. new float[channels * num_candidate_blocks]);
  248. // Energy of all candid frames.
  249. MultiChannelMovingBlockEnergies(search_block, target_size,
  250. energy_candidate_blocks.get());
  251. // Energy of target frame.
  252. MultiChannelDotProduct(target_block, 0, target_block, 0,
  253. target_size, energy_target_block.get());
  254. int optimal_index = DecimatedSearch(kSearchDecimation,
  255. exclude_interval, target_block,
  256. search_block, energy_target_block.get(),
  257. energy_candidate_blocks.get());
  258. int lim_low = std::max(0, optimal_index - kSearchDecimation);
  259. int lim_high = std::min(num_candidate_blocks - 1,
  260. optimal_index + kSearchDecimation);
  261. return FullSearch(lim_low, lim_high, exclude_interval, target_block,
  262. search_block, energy_target_block.get(),
  263. energy_candidate_blocks.get());
  264. }
  265. void GetPeriodicHanningWindow(int window_length, float* window) {
  266. const float scale = 2.0f * base::kPiFloat / window_length;
  267. for (int n = 0; n < window_length; ++n)
  268. window[n] = 0.5f * (1.0f - std::cos(n * scale));
  269. }
  270. } // namespace internal
  271. } // namespace media