base_predictor.cc 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/base_predictor.h"
  5. #include "base/containers/contains.h"
  6. #include "base/feature_list.h"
  7. #include "components/assist_ranker/proto/ranker_example.pb.h"
  8. #include "components/assist_ranker/proto/ranker_model.pb.h"
  9. #include "components/assist_ranker/ranker_example_util.h"
  10. #include "components/assist_ranker/ranker_model.h"
  11. #include "services/metrics/public/cpp/ukm_entry_builder.h"
  12. #include "services/metrics/public/cpp/ukm_recorder.h"
  13. #include "url/gurl.h"
  14. namespace assist_ranker {
  15. BasePredictor::BasePredictor(const PredictorConfig& config) : config_(config) {
  16. // TODO(chrome-ranker-team): validate config.
  17. if (config_.field_trial) {
  18. is_query_enabled_ = base::FeatureList::IsEnabled(*config_.field_trial);
  19. } else {
  20. DVLOG(0) << "No field trial specified";
  21. }
  22. }
  23. BasePredictor::~BasePredictor() {}
  24. void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) {
  25. if (!is_query_enabled_)
  26. return;
  27. if (model_loader_) {
  28. DVLOG(0) << "This predictor already has a model loader.";
  29. return;
  30. }
  31. // Take ownership of the model loader.
  32. model_loader_ = std::move(model_loader);
  33. // Kick off the initial model load.
  34. model_loader_->NotifyOfRankerActivity();
  35. }
  36. void BasePredictor::OnModelAvailable(
  37. std::unique_ptr<assist_ranker::RankerModel> model) {
  38. ranker_model_ = std::move(model);
  39. is_ready_ = Initialize();
  40. }
  41. bool BasePredictor::IsReady() {
  42. if (!is_ready_ && model_loader_)
  43. model_loader_->NotifyOfRankerActivity();
  44. return is_ready_;
  45. }
  46. void BasePredictor::LogFeatureToUkm(const std::string& feature_name,
  47. const Feature& feature,
  48. ukm::UkmEntryBuilder* ukm_builder) {
  49. DCHECK(ukm_builder);
  50. if (!base::Contains(*config_.feature_allowlist, feature_name)) {
  51. DVLOG(1) << "Feature not allowed: " << feature_name;
  52. return;
  53. }
  54. switch (feature.feature_type_case()) {
  55. case Feature::kBoolValue:
  56. case Feature::kFloatValue:
  57. case Feature::kInt32Value:
  58. case Feature::kStringValue: {
  59. int64_t feature_int64_value = -1;
  60. FeatureToInt64(feature, &feature_int64_value);
  61. DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value;
  62. ukm_builder->SetMetric(feature_name, feature_int64_value);
  63. break;
  64. }
  65. case Feature::kStringList: {
  66. for (int i = 0; i < feature.string_list().string_value_size(); ++i) {
  67. int64_t feature_int64_value = -1;
  68. FeatureToInt64(feature, &feature_int64_value, i);
  69. DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value;
  70. ukm_builder->SetMetric(feature_name, feature_int64_value);
  71. }
  72. break;
  73. }
  74. default:
  75. DVLOG(0) << "Could not convert feature to int: " << feature_name;
  76. }
  77. }
  78. void BasePredictor::LogExampleToUkm(const RankerExample& example,
  79. ukm::SourceId source_id) {
  80. if (config_.log_type != LOG_UKM) {
  81. DVLOG(0) << "Wrong log type in predictor config: " << config_.log_type;
  82. return;
  83. }
  84. if (!config_.feature_allowlist) {
  85. DVLOG(0) << "No allowlist specified.";
  86. return;
  87. }
  88. if (config_.feature_allowlist->empty()) {
  89. DVLOG(0) << "Empty allowlist, examples will not be logged.";
  90. return;
  91. }
  92. ukm::UkmEntryBuilder builder(source_id, config_.logging_name);
  93. for (const auto& feature_kv : example.features()) {
  94. LogFeatureToUkm(feature_kv.first, feature_kv.second, &builder);
  95. }
  96. builder.Record(ukm::UkmRecorder::Get());
  97. }
  98. std::string BasePredictor::GetModelName() const {
  99. return config_.model_name;
  100. }
  101. GURL BasePredictor::GetModelUrl() const {
  102. if (!config_.field_trial_url_param) {
  103. DVLOG(1) << "No URL specified.";
  104. return GURL();
  105. }
  106. return GURL(config_.field_trial_url_param->Get());
  107. }
  108. float BasePredictor::GetPredictThresholdReplacement() const {
  109. return config_.field_trial_threshold_replacement_param;
  110. }
  111. RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
  112. if (ranker_model_->proto().has_metadata() &&
  113. ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
  114. return HashExampleFeatureNames(example);
  115. }
  116. return example;
  117. }
  118. } // namespace assist_ranker