one_hot.cc 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/learning/impl/one_hot.h"
  5. #include <set>
  6. namespace media {
  7. namespace learning {
  8. OneHotConverter::OneHotConverter(const LearningTask& task,
  9. const TrainingData& training_data)
  10. : converted_task_(task) {
  11. converted_task_.feature_descriptions.clear();
  12. // store
  13. converters_.resize(task.feature_descriptions.size());
  14. for (size_t i = 0; i < task.feature_descriptions.size(); i++) {
  15. const LearningTask::ValueDescription& feature =
  16. task.feature_descriptions[i];
  17. // If this is already a numeric feature, then we will copy it since
  18. // converters[i] will be unset.
  19. if (feature.ordering == LearningTask::Ordering::kNumeric) {
  20. converted_task_.feature_descriptions.push_back(feature);
  21. continue;
  22. }
  23. ProcessOneFeature(i, feature, training_data);
  24. }
  25. }
  26. OneHotConverter::~OneHotConverter() = default;
  27. TrainingData OneHotConverter::Convert(const TrainingData& training_data) const {
  28. TrainingData converted_training_data;
  29. for (auto& example : training_data) {
  30. LabelledExample converted_example(example);
  31. converted_example.features = Convert(example.features);
  32. converted_training_data.push_back(converted_example);
  33. }
  34. return converted_training_data;
  35. }
  36. FeatureVector OneHotConverter::Convert(
  37. const FeatureVector& feature_vector) const {
  38. FeatureVector converted_feature_vector;
  39. converted_feature_vector.reserve(converted_task_.feature_descriptions.size());
  40. for (size_t i = 0; i < converters_.size(); i++) {
  41. auto& converter = converters_[i];
  42. if (!converter) {
  43. // There's no conversion needed for this feature, since it was numeric.
  44. converted_feature_vector.push_back(feature_vector[i]);
  45. continue;
  46. }
  47. // Convert this feature to a one-hot vector.
  48. const size_t vector_size = converter->size();
  49. // Start with a zero-hot vector. Is that a thing?
  50. for (size_t v = 0; v < vector_size; v++)
  51. converted_feature_vector.push_back(FeatureValue(0));
  52. // Set the appropriate entry to 1, if any. Otherwise, this is a
  53. // previously unseen value and all of them should be zero.
  54. auto iter = converter->find(feature_vector[i]);
  55. if (iter != converter->end())
  56. converted_feature_vector[iter->second] = FeatureValue(1);
  57. }
  58. return converted_feature_vector;
  59. }
  60. void OneHotConverter::ProcessOneFeature(
  61. size_t index,
  62. const LearningTask::ValueDescription& original_description,
  63. const TrainingData& training_data) {
  64. // Collect all the distinct values for |index|.
  65. std::set<Value> values;
  66. for (auto& example : training_data) {
  67. DCHECK_GE(example.features.size(), index);
  68. values.insert(example.features[index]);
  69. }
  70. // We let the set's ordering be the one-hot value. It doesn't really matter
  71. // as long as we don't change it once we pick it.
  72. ValueVectorIndexMap value_map;
  73. // Vector index that should be set to one for each distinct value. This will
  74. // start at the next feature in the adjusted task.
  75. size_t next_vector_index = converted_task_.feature_descriptions.size();
  76. // Add one feature for each value, and construct a map from value to the
  77. // feature index that should be 1 when the feature takes that value.
  78. for (auto& value : values) {
  79. LearningTask::ValueDescription adjusted_description = original_description;
  80. adjusted_description.ordering = LearningTask::Ordering::kNumeric;
  81. converted_task_.feature_descriptions.push_back(adjusted_description);
  82. // |value| will converted into a 1 in the |next_vector_index|-th feature.
  83. value_map[value] = next_vector_index++;
  84. }
  85. // Record |values| for the |index|-th original feature.
  86. converters_[index] = std::move(value_map);
  87. }
  88. ConvertingModel::ConvertingModel(std::unique_ptr<OneHotConverter> converter,
  89. std::unique_ptr<Model> model)
  90. : converter_(std::move(converter)), model_(std::move(model)) {}
  91. ConvertingModel::~ConvertingModel() = default;
  92. TargetHistogram ConvertingModel::PredictDistribution(
  93. const FeatureVector& instance) {
  94. FeatureVector converted_instance = converter_->Convert(instance);
  95. return model_->PredictDistribution(converted_instance);
  96. }
  97. } // namespace learning
  98. } // namespace media