123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- // Copyright 2021 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/optimization_guide/core/model_util.h"
- #include "base/base64.h"
- #include "base/containers/flat_set.h"
- #include "base/logging.h"
- #include "base/notreached.h"
- #include "base/strings/string_split.h"
- #include "base/strings/utf_string_conversions.h"
- #include "build/build_config.h"
- #include "components/optimization_guide/core/optimization_guide_switches.h"
- #include "net/base/url_util.h"
- #include "url/url_canon.h"
- namespace optimization_guide {
- namespace {
- // The ":" character is reserved in Windows as part of an absolute file path,
- // e.g.: C:\model.tflite, so we use a different separtor.
- #if BUILDFLAG(IS_WIN)
- const char kModelOverrideSeparator[] = "|";
- #else
- const char kModelOverrideSeparator[] = ":";
- #endif
- } // namespace
- // These names are persisted to histograms, so don't change them.
- std::string GetStringNameForOptimizationTarget(
- optimization_guide::proto::OptimizationTarget optimization_target) {
- switch (optimization_target) {
- case proto::OPTIMIZATION_TARGET_UNKNOWN:
- return "Unknown";
- case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:
- return "PainfulPageLoad";
- case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
- return "LanguageDetection";
- case proto::OPTIMIZATION_TARGET_PAGE_TOPICS:
- return "PageTopics";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
- return "SegmentationNewTab";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
- return "SegmentationShare";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
- return "SegmentationVoice";
- case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION:
- return "ModelValidation";
- case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES:
- return "PageEntities";
- case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS:
- return "NotificationPermissions";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
- return "SegmentationDummyFeature";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
- return "SegmentationChromeStartAndroid";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
- return "SegmentationQueryTiles";
- case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY:
- return "PageVisibility";
- case proto::OPTIMIZATION_TARGET_AUTOFILL_ASSISTANT:
- return "AutofillAssistant";
- case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2:
- return "PageTopicsV2";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
- return "SegmentationChromeLowUserEngagement";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_FEED_USER:
- return "SegmentationFeedUser";
- case proto::OPTIMIZATION_TARGET_CONTEXTUAL_PAGE_ACTION_PRICE_TRACKING:
- return "ContextualPageActionPriceTracking";
- case proto::OPTIMIZATION_TARGET_TEXT_CLASSIFIER:
- return "TextClassifier";
- // Whenever a new value is added, make sure to add it to the OptTarget
- // variant list in
- // //tools/metrics/histograms/metadata/optimization/histograms.xml.
- }
- NOTREACHED();
- return std::string();
- }
- absl::optional<base::FilePath> StringToFilePath(const std::string& str_path) {
- if (str_path.empty())
- return absl::nullopt;
- #if BUILDFLAG(IS_WIN)
- return base::FilePath(base::UTF8ToWide(str_path));
- #else
- return base::FilePath(str_path);
- #endif
- }
- std::string FilePathToString(const base::FilePath& file_path) {
- #if BUILDFLAG(IS_WIN)
- return base::WideToUTF8(file_path.value());
- #else
- return file_path.value();
- #endif
- }
- base::FilePath GetBaseFileNameForModels() {
- return base::FilePath(FILE_PATH_LITERAL("model.tflite"));
- }
- std::string ModelOverrideSeparator() {
- return kModelOverrideSeparator;
- }
- absl::optional<
- std::pair<std::string, absl::optional<optimization_guide::proto::Any>>>
- GetModelOverrideForOptimizationTarget(
- optimization_guide::proto::OptimizationTarget optimization_target) {
- auto model_override_switch_value = switches::GetModelOverride();
- if (!model_override_switch_value)
- return absl::nullopt;
- std::vector<std::string> model_overrides =
- base::SplitString(*model_override_switch_value, ",",
- base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
- for (const auto& model_override : model_overrides) {
- std::vector<std::string> override_parts =
- base::SplitString(model_override, kModelOverrideSeparator,
- base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
- if (override_parts.size() != 2 && override_parts.size() != 3) {
- // Input is malformed.
- DLOG(ERROR) << "Invalid string format provided to the Model Override";
- return absl::nullopt;
- }
- optimization_guide::proto::OptimizationTarget recv_optimization_target;
- if (!optimization_guide::proto::OptimizationTarget_Parse(
- override_parts[0], &recv_optimization_target)) {
- // Optimization target is invalid.
- DLOG(ERROR)
- << "Invalid optimization target provided to the Model Override";
- return absl::nullopt;
- }
- if (optimization_target != recv_optimization_target)
- continue;
- std::string file_name = override_parts[1];
- base::FilePath file_path = *StringToFilePath(file_name);
- if (!file_path.IsAbsolute()) {
- DLOG(ERROR) << "Provided model file path must be absolute " << file_name;
- return absl::nullopt;
- }
- if (override_parts.size() == 2) {
- std::pair<std::string, absl::optional<optimization_guide::proto::Any>>
- file_path_and_metadata = std::make_pair(file_name, absl::nullopt);
- return file_path_and_metadata;
- }
- std::string binary_pb;
- if (!base::Base64Decode(override_parts[2], &binary_pb)) {
- DLOG(ERROR) << "Invalid base64 encoding of the Model Override";
- return absl::nullopt;
- }
- optimization_guide::proto::Any model_metadata;
- if (!model_metadata.ParseFromString(binary_pb)) {
- DLOG(ERROR) << "Invalid model metadata provided to the Model Override";
- return absl::nullopt;
- }
- std::pair<std::string, absl::optional<optimization_guide::proto::Any>>
- file_path_and_metadata = std::make_pair(file_name, model_metadata);
- return file_path_and_metadata;
- }
- return absl::nullopt;
- }
- } // namespace optimization_guide
|