distribution_reporter.h 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
  5. #define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#include <memory>
#include <set>

#include "base/callback.h"
#include "base/component_export.h"
#include "base/memory/weak_ptr.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/target_histogram.h"
#include "media/learning/impl/model.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
  15. namespace media {
  16. namespace learning {
  17. // Helper class to report on predicted distrubutions vs target distributions.
  18. // Use DistributionReporter::Create() to create one that's appropriate for a
  19. // specific learning task.
  20. class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
  21. public:
  22. // Extra information provided to the reporter for each prediction.
  23. struct PredictionInfo {
  24. // What value was observed?
  25. TargetValue observed;
  26. // UKM source id to use when logging this result.
  27. // This will be filled in by the LearningTaskController. For example, the
  28. // MojoLearningTaskControllerService will be created in the browser by the
  29. // MediaMetricsProvider, which gets the SourceId via callback from the
  30. // RenderFrameHostDelegate on construction.
  31. //
  32. // TODO(liberato): Right now, this is not filled in anywhere. When the
  33. // mojo service is created (MediaMetricsProvider), record the source id and
  34. // memorize it in any MojoLearningTaskControllerService that's created by
  35. // the MediaMetricsProvider, either directly or in a wrapper for the
  36. // mojo controller.
  37. ukm::SourceId source_id = ukm::kInvalidSourceId;
  38. // Total weight of the training data used to create this model.
  39. double total_training_weight = 0.;
  40. // Total number of examples (unweighted) in the training set.
  41. size_t total_training_examples = 0u;
  42. // TODO(liberato): Move the feature subset here.
  43. };
  44. // Create a DistributionReporter that's suitable for |task|.
  45. static std::unique_ptr<DistributionReporter> Create(const LearningTask& task);
  46. DistributionReporter(const DistributionReporter&) = delete;
  47. DistributionReporter& operator=(const DistributionReporter&) = delete;
  48. virtual ~DistributionReporter();
  49. // Returns a prediction CB that will be compared to |prediction_info.observed|
  50. // TODO(liberato): This is too complicated. Skip the callback and just call
  51. // us with the predicted value.
  52. virtual Model::PredictionCB GetPredictionCallback(
  53. const PredictionInfo& prediction_info);
  54. // Set the subset of features that is being used to train the model. This is
  55. // used for feature importance measuremnts.
  56. //
  57. // For example, sending in the set [0, 3, 7] would indicate that the model was
  58. // trained with task().feature_descriptions[0, 3, 7] only.
  59. //
  60. // Note that UMA reporting only supports single feature subsets.
  61. void SetFeatureSubset(const std::set<int>& feature_indices);
  62. protected:
  63. DistributionReporter(const LearningTask& task);
  64. const LearningTask& task() const { return task_; }
  65. // Implemented by subclasses to report a prediction.
  66. virtual void OnPrediction(const PredictionInfo& prediction_info,
  67. TargetHistogram predicted) = 0;
  68. const absl::optional<std::set<int>>& feature_indices() const {
  69. return feature_indices_;
  70. }
  71. private:
  72. LearningTask task_;
  73. // If provided, then these are the features that are used to train the model.
  74. // Otherwise, we assume that all features are used.
  75. absl::optional<std::set<int>> feature_indices_;
  76. base::WeakPtrFactory<DistributionReporter> weak_factory_{this};
  77. };
  78. } // namespace learning
  79. } // namespace media
  80. #endif // MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_