one_hot.h 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
  4. #ifndef MEDIA_LEARNING_IMPL_ONE_HOT_H_
  5. #define MEDIA_LEARNING_IMPL_ONE_HOT_H_
  6. #include <map>
  7. #include <memory>
  8. #include <vector>
  9. #include "base/component_export.h"
  10. #include "media/learning/common/labelled_example.h"
  11. #include "media/learning/common/learning_task.h"
  12. #include "media/learning/common/value.h"
  13. #include "media/learning/impl/model.h"
  14. #include "third_party/abseil-cpp/absl/types/optional.h"
  15. namespace media {
  16. namespace learning {
  17. // Converter class that memorizes a mapping from nominal features to numeric
  18. // features with a one-hot encoding.
  19. class COMPONENT_EXPORT(LEARNING_IMPL) OneHotConverter {
  20. public:
  21. // Build a one-hot converter for all nominal features |task|, using the values
  22. // found in |training_data|.
  23. OneHotConverter(const LearningTask& task, const TrainingData& training_data);
  24. OneHotConverter(const OneHotConverter&) = delete;
  25. OneHotConverter& operator=(const OneHotConverter&) = delete;
  26. ~OneHotConverter();
  27. // Return the LearningTask that has only nominal features.
  28. const LearningTask& converted_task() const { return converted_task_; }
  29. // Convert |training_data| to be a one-hot model.
  30. TrainingData Convert(const TrainingData& training_data) const;
  31. // Convert |feature_vector| to match the one-hot model.
  32. FeatureVector Convert(const FeatureVector& feature_vector) const;
  33. private:
  34. // Build a converter for original feature |index|.
  35. void ProcessOneFeature(
  36. size_t index,
  37. const LearningTask::ValueDescription& original_description,
  38. const TrainingData& training_data);
  39. // Learning task with the feature descriptions adjusted for the one-hot model.
  40. LearningTask converted_task_;
  41. // [value] == vector index that should be 1 in the one-hot vector.
  42. using ValueVectorIndexMap = std::map<Value, size_t>;
  43. // [original task feature index] = optional converter for it. If the feature
  44. // was kNumeric to begin with, then there will be no converter.
  45. std::vector<absl::optional<ValueVectorIndexMap>> converters_;
  46. };
  47. // Model that uses |Converter| to convert instances before sending them to the
  48. // underlying model.
  49. class COMPONENT_EXPORT(LEARNING_IMPL) ConvertingModel : public Model {
  50. public:
  51. ConvertingModel(std::unique_ptr<OneHotConverter> converter,
  52. std::unique_ptr<Model> model);
  53. ConvertingModel(const ConvertingModel&) = delete;
  54. ConvertingModel& operator=(const ConvertingModel&) = delete;
  55. ~ConvertingModel() override;
  56. // Model
  57. TargetHistogram PredictDistribution(const FeatureVector& instance) override;
  58. private:
  59. std::unique_ptr<OneHotConverter> converter_;
  60. std::unique_ptr<Model> model_;
  61. };
  62. } // namespace learning
  63. } // namespace media
  64. #endif // MEDIA_LEARNING_IMPL_ONE_HOT_H_