example_preprocessing.cc 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/example_preprocessing.h"
  5. #include "base/cxx17_backports.h"
  6. #include "base/logging.h"
  7. #include "base/strings/strcat.h"
  8. #include "base/strings/string_number_conversions.h"
  9. #include "components/assist_ranker/ranker_example_util.h"
  10. #include "third_party/protobuf/src/google/protobuf/map.h"
  11. #include "third_party/protobuf/src/google/protobuf/repeated_field.h"
  12. namespace assist_ranker {
  13. using google::protobuf::Map;
  14. using google::protobuf::MapPair;
  15. using google::protobuf::RepeatedField;
  16. // Initialize.
  17. const char ExamplePreprocessor::kMissingFeatureDefaultName[] =
  18. "_MissingFeature";
  19. const char ExamplePreprocessor::kVectorizedFeatureDefaultName[] =
  20. "_VectorizedFeature";
  21. std::string ExamplePreprocessor::FeatureFullname(
  22. const std::string& feature_name,
  23. const std::string& feature_value) {
  24. return feature_value.empty()
  25. ? feature_name
  26. : base::StrCat({feature_name, "_", feature_value});
  27. }
  28. int ExamplePreprocessor::Process(const ExamplePreprocessorConfig& config,
  29. RankerExample* const example,
  30. const bool clear_other_features) {
  31. return AddMissingFeatures(config, example) |
  32. NormalizeFeatures(config, example) |
  33. AddBucketizedFeatures(config, example) |
  34. ConvertToStringFeatures(config, example) |
  35. Vectorization(config, example, clear_other_features);
  36. }
  37. int ExamplePreprocessor::AddMissingFeatures(
  38. const ExamplePreprocessorConfig& config,
  39. RankerExample* const example) {
  40. Map<std::string, Feature>& feature_map = *example->mutable_features();
  41. for (const std::string& feature_name : config.missing_features()) {
  42. // If a feature is missing in the example, set the place.
  43. if (feature_map.find(feature_name) == feature_map.end()) {
  44. feature_map[kMissingFeatureDefaultName]
  45. .mutable_string_list()
  46. ->add_string_value(feature_name);
  47. }
  48. }
  49. return kSuccess;
  50. }
  51. int ExamplePreprocessor::AddBucketizedFeatures(
  52. const ExamplePreprocessorConfig& config,
  53. RankerExample* const example) {
  54. int error_code = kSuccess;
  55. Map<std::string, Feature>& feature_map = *example->mutable_features();
  56. for (const MapPair<std::string, ExamplePreprocessorConfig::Boundaries>&
  57. bucketizer : config.bucketizers()) {
  58. const std::string& feature_name = bucketizer.first;
  59. // Simply continue if the feature is missing. The missing feature will later
  60. // on be handled as missing one_hot feature, and it's up to the user how to
  61. // handle this missing feature.
  62. Feature feature;
  63. if (!SafeGetFeature(feature_name, *example, &feature)) {
  64. continue;
  65. }
  66. // Get feature value as float. Only int32 or float value is supported for
  67. // Bucketization. Continue if the type_case is not int32 or float.
  68. float value = 0;
  69. switch (feature.feature_type_case()) {
  70. case Feature::kInt32Value:
  71. value = static_cast<float>(feature.int32_value());
  72. break;
  73. case Feature::kFloatValue:
  74. value = feature.float_value();
  75. break;
  76. default:
  77. DVLOG(2) << "Can't bucketize feature type: "
  78. << feature.feature_type_case();
  79. error_code |= kNonbucketizableFeatureType;
  80. continue;
  81. }
  82. // Get the bucket from the boundaries; the first index that value<boundary.
  83. const RepeatedField<float>& boundaries = bucketizer.second.boundaries();
  84. int index = 0;
  85. for (; index < boundaries.size(); ++index) {
  86. if (value < boundaries[index])
  87. break;
  88. }
  89. // Set one hot feature as features[feature_name] = "index";
  90. feature_map[feature_name].set_string_value(base::NumberToString(index));
  91. }
  92. return error_code;
  93. }
  94. int ExamplePreprocessor::NormalizeFeatures(
  95. const ExamplePreprocessorConfig& config,
  96. RankerExample* example) {
  97. int error_code = kSuccess;
  98. for (const MapPair<std::string, float>& pair : config.normalizers()) {
  99. const std::string& feature_name = pair.first;
  100. float feature_value = 0.0f;
  101. if (GetFeatureValueAsFloat(feature_name, *example, &feature_value)) {
  102. if (pair.second == 0.0f) {
  103. error_code |= kNormalizerIsZero;
  104. } else {
  105. feature_value = feature_value / pair.second;
  106. }
  107. // Truncate to be within [-1.0, 1.0].
  108. feature_value = base::clamp(feature_value, -1.0f, 1.0f);
  109. (*example->mutable_features())[feature_name].set_float_value(
  110. feature_value);
  111. } else {
  112. error_code |= kNonNormalizableFeatureType;
  113. }
  114. }
  115. return error_code;
  116. }
  117. int ExamplePreprocessor::ConvertToStringFeatures(
  118. const ExamplePreprocessorConfig& config,
  119. RankerExample* example) {
  120. int error_code = kSuccess;
  121. for (const std::string& feature_name : config.convert_to_string_features()) {
  122. const auto find_feature = example->mutable_features()->find(feature_name);
  123. if (find_feature != example->features().end()) {
  124. auto& feature = find_feature->second;
  125. switch (feature.feature_type_case()) {
  126. case Feature::kBoolValue:
  127. feature.set_string_value(
  128. base::NumberToString(static_cast<int>(feature.bool_value())));
  129. break;
  130. case Feature::kInt32Value:
  131. feature.set_string_value(base::NumberToString(feature.int32_value()));
  132. break;
  133. case Feature::kStringValue:
  134. break;
  135. default:
  136. LOG(WARNING) << "Can't convert to string feature type: "
  137. << feature.feature_type_case();
  138. error_code |= kNonConvertibleToStringFeatureType;
  139. continue;
  140. }
  141. }
  142. }
  143. return error_code;
  144. }
  145. int ExamplePreprocessor::Vectorization(const ExamplePreprocessorConfig& config,
  146. RankerExample* example,
  147. const bool clear_other_features) {
  148. if (config.feature_indices().empty()) {
  149. DVLOG(2) << "Feature indices are empty, can't vectorize.";
  150. return kSuccess;
  151. }
  152. Feature vectorized_features;
  153. vectorized_features.mutable_float_list()->mutable_float_value()->Resize(
  154. config.feature_indices().size(), 0.0);
  155. int error_code = kSuccess;
  156. for (const auto& field : ExampleFloatIterator(*example)) {
  157. error_code |= field.error;
  158. if (field.error != kSuccess) {
  159. continue;
  160. }
  161. const auto find_index = config.feature_indices().find(field.fullname);
  162. // If the feature_fullname is inside the indices map, then set the place.
  163. if (find_index != config.feature_indices().end()) {
  164. vectorized_features.mutable_float_list()->set_float_value(
  165. find_index->second, field.value);
  166. } else {
  167. DVLOG(2) << "Feature has no index: " << field.fullname;
  168. error_code |= kNoFeatureIndexFound;
  169. }
  170. }
  171. if (clear_other_features) {
  172. example->clear_features();
  173. }
  174. (*example->mutable_features())[kVectorizedFeatureDefaultName] =
  175. vectorized_features;
  176. return error_code;
  177. }
  178. ExampleFloatIterator::Field ExampleFloatIterator::operator*() const {
  179. const std::string& feature_name = feature_iterator_->first;
  180. const Feature& feature = feature_iterator_->second;
  181. Field field = {feature_name, 1.0f, ExamplePreprocessor::kSuccess};
  182. switch (feature.feature_type_case()) {
  183. case Feature::kBoolValue:
  184. field.value = static_cast<float>(feature.bool_value());
  185. break;
  186. case Feature::kInt32Value:
  187. field.value = static_cast<float>(feature.int32_value());
  188. break;
  189. case Feature::kFloatValue:
  190. field.value = feature.float_value();
  191. break;
  192. case Feature::kStringValue:
  193. field.fullname = ExamplePreprocessor::FeatureFullname(
  194. feature_name, feature.string_value());
  195. break;
  196. case Feature::kStringList:
  197. if (string_list_index_ < feature.string_list().string_value_size()) {
  198. const std::string& string_value =
  199. feature.string_list().string_value(string_list_index_);
  200. field.fullname =
  201. ExamplePreprocessor::FeatureFullname(feature_name, string_value);
  202. } else {
  203. // This happens when a string list field is added without any value.
  204. field.error = ExamplePreprocessor::kInvalidFeatureListIndex;
  205. }
  206. break;
  207. default:
  208. field.error = ExamplePreprocessor::kInvalidFeatureType;
  209. DVLOG(2) << "Feature type not supported: "
  210. << feature.feature_type_case();
  211. break;
  212. }
  213. return field;
  214. }
  215. ExampleFloatIterator& ExampleFloatIterator::operator++() {
  216. const Feature& feature = feature_iterator_->second;
  217. switch (feature.feature_type_case()) {
  218. case Feature::kBoolValue:
  219. case Feature::kInt32Value:
  220. case Feature::kFloatValue:
  221. case Feature::kStringValue:
  222. ++feature_iterator_;
  223. break;
  224. case Feature::kStringList:
  225. if (string_list_index_ < feature.string_list().string_value_size() - 1) {
  226. // If not at the last element, advance the index.
  227. ++string_list_index_;
  228. } else {
  229. // If at the last element, advance the feature_iterator.
  230. string_list_index_ = 0;
  231. ++feature_iterator_;
  232. }
  233. break;
  234. default:
  235. ++feature_iterator_;
  236. DVLOG(2) << "Feature type not supported: "
  237. << feature.feature_type_case();
  238. }
  239. return *this;
  240. }
  241. } // namespace assist_ranker