123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- // Copyright 2018 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
- #define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
- #include <set>
- #include "base/callback.h"
- #include "base/component_export.h"
- #include "base/memory/weak_ptr.h"
- #include "media/learning/common/learning_task.h"
- #include "media/learning/common/target_histogram.h"
- #include "media/learning/impl/model.h"
- #include "services/metrics/public/cpp/ukm_source_id.h"
- #include "third_party/abseil-cpp/absl/types/optional.h"
- namespace media {
- namespace learning {
- // Helper class to report on predicted distrubutions vs target distributions.
- // Use DistributionReporter::Create() to create one that's appropriate for a
- // specific learning task.
- class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
- public:
- // Extra information provided to the reporter for each prediction.
- struct PredictionInfo {
- // What value was observed?
- TargetValue observed;
- // UKM source id to use when logging this result.
- // This will be filled in by the LearningTaskController. For example, the
- // MojoLearningTaskControllerService will be created in the browser by the
- // MediaMetricsProvider, which gets the SourceId via callback from the
- // RenderFrameHostDelegate on construction.
- //
- // TODO(liberato): Right now, this is not filled in anywhere. When the
- // mojo service is created (MediaMetricsProvider), record the source id and
- // memorize it in any MojoLearningTaskControllerService that's created by
- // the MediaMetricsProvider, either directly or in a wrapper for the
- // mojo controller.
- ukm::SourceId source_id = ukm::kInvalidSourceId;
- // Total weight of the training data used to create this model.
- double total_training_weight = 0.;
- // Total number of examples (unweighted) in the training set.
- size_t total_training_examples = 0u;
- // TODO(liberato): Move the feature subset here.
- };
- // Create a DistributionReporter that's suitable for |task|.
- static std::unique_ptr<DistributionReporter> Create(const LearningTask& task);
- DistributionReporter(const DistributionReporter&) = delete;
- DistributionReporter& operator=(const DistributionReporter&) = delete;
- virtual ~DistributionReporter();
- // Returns a prediction CB that will be compared to |prediction_info.observed|
- // TODO(liberato): This is too complicated. Skip the callback and just call
- // us with the predicted value.
- virtual Model::PredictionCB GetPredictionCallback(
- const PredictionInfo& prediction_info);
- // Set the subset of features that is being used to train the model. This is
- // used for feature importance measuremnts.
- //
- // For example, sending in the set [0, 3, 7] would indicate that the model was
- // trained with task().feature_descriptions[0, 3, 7] only.
- //
- // Note that UMA reporting only supports single feature subsets.
- void SetFeatureSubset(const std::set<int>& feature_indices);
- protected:
- DistributionReporter(const LearningTask& task);
- const LearningTask& task() const { return task_; }
- // Implemented by subclasses to report a prediction.
- virtual void OnPrediction(const PredictionInfo& prediction_info,
- TargetHistogram predicted) = 0;
- const absl::optional<std::set<int>>& feature_indices() const {
- return feature_indices_;
- }
- private:
- LearningTask task_;
- // If provided, then these are the features that are used to train the model.
- // Otherwise, we assume that all features are used.
- absl::optional<std::set<int>> feature_indices_;
- base::WeakPtrFactory<DistributionReporter> weak_factory_{this};
- };
- } // namespace learning
- } // namespace media
- #endif // MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
|