lookup_table_trainer.cc 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/learning/impl/lookup_table_trainer.h"
  5. #include <map>
  6. namespace media {
  7. namespace learning {
  8. class LookupTable : public Model {
  9. public:
  10. LookupTable(const TrainingData& training_data) {
  11. for (auto& example : training_data)
  12. buckets_[example.features] += example;
  13. }
  14. // Model
  15. TargetHistogram PredictDistribution(const FeatureVector& instance) override {
  16. auto iter = buckets_.find(instance);
  17. if (iter == buckets_.end())
  18. return TargetHistogram();
  19. return iter->second;
  20. }
  21. private:
  22. std::map<FeatureVector, TargetHistogram> buckets_;
  23. };
  24. LookupTableTrainer::LookupTableTrainer() = default;
  25. LookupTableTrainer::~LookupTableTrainer() = default;
  26. void LookupTableTrainer::Train(const LearningTask& task,
  27. const TrainingData& training_data,
  28. TrainedModelCB model_cb) {
  29. std::unique_ptr<LookupTable> lookup_table =
  30. std::make_unique<LookupTable>(training_data);
  31. // TODO(liberato): post?
  32. std::move(model_cb).Run(std::move(lookup_table));
  33. }
  34. } // namespace learning
  35. } // namespace media