extra_trees_trainer.h 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef MEDIA_LEARNING_IMPL_EXTRA_TREES_TRAINER_H_
  5. #define MEDIA_LEARNING_IMPL_EXTRA_TREES_TRAINER_H_
  6. #include <memory>
  7. #include <vector>
  8. #include "base/component_export.h"
  9. #include "base/memory/weak_ptr.h"
  10. #include "media/learning/common/learning_task.h"
  11. #include "media/learning/impl/one_hot.h"
  12. #include "media/learning/impl/random_number_generator.h"
  13. #include "media/learning/impl/random_tree_trainer.h"
  14. #include "media/learning/impl/training_algorithm.h"
  15. namespace media {
  16. namespace learning {
  17. // Bagged forest of extremely randomized trees.
  18. //
  19. // These are an ensemble of trees. Each tree is constructed from the full
  20. // training set. The trees are constructed by selecting a random subset of
  21. // features at each node. For each feature, a uniformly random split point is
  22. // chosen. The feature with the best randomly chosen split point is used.
  23. //
  24. // These will automatically convert nominal values to one-hot vectors.
      //
      // NOTE(review): "bagged" usually implies bootstrap resampling of the
      // training set, while the paragraph above says each tree is built from the
      // full training set — confirm which one the implementation actually does.
      //
      // The class is non-copyable (copy constructor and copy assignment are
      // deleted below). Only declarations live in this header; the behavior
      // described in the hedged notes below should be checked against the .cc.
  25. class COMPONENT_EXPORT(LEARNING_IMPL) ExtraTreesTrainer
  26. : public TrainingAlgorithm,
  27. public HasRandomNumberGenerator,
      // Lets weak pointers to this trainer be handed out; presumably used so that
      // pending per-tree callbacks do not outlive the trainer — TODO confirm.
  28. public base::SupportsWeakPtr<ExtraTreesTrainer> {
  29. public:
  30. ExtraTreesTrainer();
      // Copying is disabled.
  31. ExtraTreesTrainer(const ExtraTreesTrainer&) = delete;
  32. ExtraTreesTrainer& operator=(const ExtraTreesTrainer&) = delete;
  33. ~ExtraTreesTrainer() override;
  34. // TrainingAlgorithm
      //
      // Override of the base-class training entry point. Receives the task
      // description (|task|), the examples to learn from (|training_data|), and
      // a TrainedModelCB (|model_cb|). The definition is not in this header;
      // presumably |model_cb| is run with the finished forest once all trees
      // have been trained — confirm in the .cc file.
  35. void Train(const LearningTask& task,
  36. const TrainingData& training_data,
  37. TrainedModelCB model_cb) override;
  38. private:
      // Receives one trained |model| together with the caller's |model_cb|.
      // Presumably invoked each time |tree_trainer_| finishes a single tree, and
      // either starts the next tree or completes training — TODO confirm.
  39. void OnRandomTreeModel(TrainedModelCB model_cb, std::unique_ptr<Model> model);
      // Owned trainer, held through the TrainingAlgorithm interface. Presumably
      // a RandomTreeTrainer (this header includes random_tree_trainer.h) —
      // verify where it is created.
  40. std::unique_ptr<TrainingAlgorithm> tree_trainer_;
  41. // In-flight training.
      // The members below hold state for the training run in progress.
      // |task_| is stored by value, i.e. it is a copy rather than a reference.
  42. LearningTask task_;
      // Owned models; presumably the trees produced so far for the current run.
  43. std::vector<std::unique_ptr<Model>> trees_;
      // Owned one-hot converter (see the class comment about nominal values).
  44. std::unique_ptr<OneHotConverter> converter_;
      // Presumably the training data after |converter_| has been applied to it;
      // stored by value — TODO confirm against the .cc file.
  45. TrainingData converted_training_data_;
  46. };
  47. } // namespace learning
  48. } // namespace media
  49. #endif // MEDIA_LEARNING_IMPL_EXTRA_TREES_TRAINER_H_