model_util.cc 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/model_util.h"
  5. #include "base/base64.h"
  6. #include "base/containers/flat_set.h"
  7. #include "base/logging.h"
  8. #include "base/notreached.h"
  9. #include "base/strings/string_split.h"
  10. #include "base/strings/utf_string_conversions.h"
  11. #include "build/build_config.h"
  12. #include "components/optimization_guide/core/optimization_guide_switches.h"
  13. #include "net/base/url_util.h"
  14. #include "url/url_canon.h"
  15. namespace optimization_guide {
  16. namespace {
  17. // The ":" character is reserved in Windows as part of an absolute file path,
  18. // e.g.: C:\model.tflite, so we use a different separtor.
  19. #if BUILDFLAG(IS_WIN)
  20. const char kModelOverrideSeparator[] = "|";
  21. #else
  22. const char kModelOverrideSeparator[] = ":";
  23. #endif
  24. } // namespace
  25. // These names are persisted to histograms, so don't change them.
  26. std::string GetStringNameForOptimizationTarget(
  27. optimization_guide::proto::OptimizationTarget optimization_target) {
  28. switch (optimization_target) {
  29. case proto::OPTIMIZATION_TARGET_UNKNOWN:
  30. return "Unknown";
  31. case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:
  32. return "PainfulPageLoad";
  33. case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
  34. return "LanguageDetection";
  35. case proto::OPTIMIZATION_TARGET_PAGE_TOPICS:
  36. return "PageTopics";
  37. case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
  38. return "SegmentationNewTab";
  39. case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
  40. return "SegmentationShare";
  41. case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
  42. return "SegmentationVoice";
  43. case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION:
  44. return "ModelValidation";
  45. case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES:
  46. return "PageEntities";
  47. case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS:
  48. return "NotificationPermissions";
  49. case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
  50. return "SegmentationDummyFeature";
  51. case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
  52. return "SegmentationChromeStartAndroid";
  53. case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
  54. return "SegmentationQueryTiles";
  55. case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY:
  56. return "PageVisibility";
  57. case proto::OPTIMIZATION_TARGET_AUTOFILL_ASSISTANT:
  58. return "AutofillAssistant";
  59. case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2:
  60. return "PageTopicsV2";
  61. case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
  62. return "SegmentationChromeLowUserEngagement";
  63. case proto::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
  64. return "SegmentationFeedUser";
  65. case proto::OPTIMIZATION_TARGET_CONTEXTUAL_PAGE_ACTION_PRICE_TRACKING:
  66. return "ContextualPageActionPriceTracking";
  67. case proto::OPTIMIZATION_TARGET_TEXT_CLASSIFIER:
  68. return "TextClassifier";
  69. // Whenever a new value is added, make sure to add it to the OptTarget
  70. // variant list in
  71. // //tools/metrics/histograms/metadata/optimization/histograms.xml.
  72. }
  73. NOTREACHED();
  74. return std::string();
  75. }
  76. absl::optional<base::FilePath> StringToFilePath(const std::string& str_path) {
  77. if (str_path.empty())
  78. return absl::nullopt;
  79. #if BUILDFLAG(IS_WIN)
  80. return base::FilePath(base::UTF8ToWide(str_path));
  81. #else
  82. return base::FilePath(str_path);
  83. #endif
  84. }
  85. std::string FilePathToString(const base::FilePath& file_path) {
  86. #if BUILDFLAG(IS_WIN)
  87. return base::WideToUTF8(file_path.value());
  88. #else
  89. return file_path.value();
  90. #endif
  91. }
  92. base::FilePath GetBaseFileNameForModels() {
  93. return base::FilePath(FILE_PATH_LITERAL("model.tflite"));
  94. }
  95. std::string ModelOverrideSeparator() {
  96. return kModelOverrideSeparator;
  97. }
  98. absl::optional<
  99. std::pair<std::string, absl::optional<optimization_guide::proto::Any>>>
  100. GetModelOverrideForOptimizationTarget(
  101. optimization_guide::proto::OptimizationTarget optimization_target) {
  102. auto model_override_switch_value = switches::GetModelOverride();
  103. if (!model_override_switch_value)
  104. return absl::nullopt;
  105. std::vector<std::string> model_overrides =
  106. base::SplitString(*model_override_switch_value, ",",
  107. base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
  108. for (const auto& model_override : model_overrides) {
  109. std::vector<std::string> override_parts =
  110. base::SplitString(model_override, kModelOverrideSeparator,
  111. base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
  112. if (override_parts.size() != 2 && override_parts.size() != 3) {
  113. // Input is malformed.
  114. DLOG(ERROR) << "Invalid string format provided to the Model Override";
  115. return absl::nullopt;
  116. }
  117. optimization_guide::proto::OptimizationTarget recv_optimization_target;
  118. if (!optimization_guide::proto::OptimizationTarget_Parse(
  119. override_parts[0], &recv_optimization_target)) {
  120. // Optimization target is invalid.
  121. DLOG(ERROR)
  122. << "Invalid optimization target provided to the Model Override";
  123. return absl::nullopt;
  124. }
  125. if (optimization_target != recv_optimization_target)
  126. continue;
  127. std::string file_name = override_parts[1];
  128. base::FilePath file_path = *StringToFilePath(file_name);
  129. if (!file_path.IsAbsolute()) {
  130. DLOG(ERROR) << "Provided model file path must be absolute " << file_name;
  131. return absl::nullopt;
  132. }
  133. if (override_parts.size() == 2) {
  134. std::pair<std::string, absl::optional<optimization_guide::proto::Any>>
  135. file_path_and_metadata = std::make_pair(file_name, absl::nullopt);
  136. return file_path_and_metadata;
  137. }
  138. std::string binary_pb;
  139. if (!base::Base64Decode(override_parts[2], &binary_pb)) {
  140. DLOG(ERROR) << "Invalid base64 encoding of the Model Override";
  141. return absl::nullopt;
  142. }
  143. optimization_guide::proto::Any model_metadata;
  144. if (!model_metadata.ParseFromString(binary_pb)) {
  145. DLOG(ERROR) << "Invalid model metadata provided to the Model Override";
  146. return absl::nullopt;
  147. }
  148. std::pair<std::string, absl::optional<optimization_guide::proto::Any>>
  149. file_path_and_metadata = std::make_pair(file_name, model_metadata);
  150. return file_path_and_metadata;
  151. }
  152. return absl::nullopt;
  153. }
  154. } // namespace optimization_guide