weighted_moving_linear_regression.cc 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "chromecast/base/statistics/weighted_moving_linear_regression.h"
  5. #include <math.h>
  6. #include <algorithm>
  7. #include "base/check_op.h"
  8. #include "base/logging.h"
  9. namespace chromecast {
  10. WeightedMovingLinearRegression::WeightedMovingLinearRegression(
  11. int64_t max_x_range)
  12. : max_x_range_(max_x_range) {
  13. DCHECK_GE(max_x_range_, 0);
  14. }
  15. WeightedMovingLinearRegression::~WeightedMovingLinearRegression() = default;
  16. void WeightedMovingLinearRegression::Reserve(int count) {
  17. Sample sample = {0, 0, 0};
  18. samples_.insert(samples_.end(), count, sample);
  19. samples_.erase(samples_.end() - count, samples_.end());
  20. }
  21. void WeightedMovingLinearRegression::Reset() {
  22. x_mean_.Reset();
  23. y_mean_.Reset();
  24. covariance_ = 0.0;
  25. samples_.clear();
  26. slope_ = 0.0;
  27. slope_variance_ = 0.0;
  28. intercept_variance_ = 0.0;
  29. has_estimate_ = false;
  30. }
  31. void WeightedMovingLinearRegression::AddSample(int64_t x,
  32. int64_t y,
  33. double weight) {
  34. DCHECK_GE(weight, 0);
  35. if (!samples_.empty())
  36. DCHECK_GE(x, samples_.back().x);
  37. UpdateSet(x, y, weight);
  38. Sample sample = {x, y, weight};
  39. samples_.push_back(sample);
  40. // Remove old samples.
  41. while (x - samples_.front().x > max_x_range_) {
  42. const Sample& old_sample = samples_.front();
  43. UpdateSet(old_sample.x, old_sample.y, -old_sample.weight);
  44. samples_.pop_front();
  45. }
  46. DCHECK(!samples_.empty());
  47. if (samples_.size() <= 2 || x_mean_.sum_weights() == 0 ||
  48. x_mean_.variance_sum() == 0) {
  49. has_estimate_ = false;
  50. return;
  51. }
  52. slope_ = covariance_ / x_mean_.variance_sum();
  53. double residual_sum_squares =
  54. (covariance_ * covariance_) / x_mean_.variance_sum();
  55. double mean_squared_error =
  56. (y_mean_.variance_sum() - residual_sum_squares) / (samples_.size() - 2);
  57. slope_variance_ = std::max(0.0, mean_squared_error / x_mean_.variance_sum());
  58. intercept_variance_ = std::max(
  59. 0.0, (slope_variance_ * x_mean_.variance_sum()) / x_mean_.sum_weights());
  60. has_estimate_ = true;
  61. }
  62. bool WeightedMovingLinearRegression::EstimateY(int64_t x,
  63. int64_t* y,
  64. double* error) const {
  65. if (!has_estimate_)
  66. return false;
  67. double x_diff = x - x_mean_.weighted_mean();
  68. double y_estimate = y_mean_.weighted_mean() + (slope_ * x_diff);
  69. *y = static_cast<int64_t>(round(y_estimate));
  70. *error = sqrt(intercept_variance_ + (slope_variance_ * x_diff * x_diff));
  71. return true;
  72. }
  73. bool WeightedMovingLinearRegression::EstimateSlope(double* slope,
  74. double* error) const {
  75. if (!has_estimate_)
  76. return false;
  77. *slope = slope_;
  78. *error = sqrt(slope_variance_);
  79. return true;
  80. }
  81. void WeightedMovingLinearRegression::UpdateSet(int64_t x,
  82. int64_t y,
  83. double weight) {
  84. double old_y_mean = y_mean_.weighted_mean();
  85. x_mean_.AddSample(x, weight);
  86. y_mean_.AddSample(y, weight);
  87. covariance_ += weight * (x - x_mean_.weighted_mean()) * (y - old_y_mean);
  88. }
  89. void WeightedMovingLinearRegression::DumpSamples() const {
  90. for (auto sample : samples_) {
  91. LOG(INFO) << "x, y, weight: " << sample.x << " " << sample.y << " "
  92. << sample.weight;
  93. }
  94. }
  95. } // namespace chromecast