prediction_model_override.cc 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/prediction_model_override.h"
  5. #include "base/files/file_util.h"
  6. #include "base/strings/string_util.h"
  7. #include "base/task/sequenced_task_runner.h"
  8. #include "base/task/thread_pool.h"
  9. #include "base/threading/thread_task_runner_handle.h"
  10. #include "components/optimization_guide/core/model_util.h"
  11. #include "components/optimization_guide/core/prediction_model_download_manager.h"
  12. #include "components/services/unzip/public/cpp/unzip.h"
  13. #if BUILDFLAG(IS_IOS)
  14. #include "components/services/unzip/in_process_unzipper.h" // nogncheck
  15. #else
  16. #include "components/services/unzip/content/unzip_service.h" // nogncheck
  17. #endif
  18. namespace optimization_guide {
  19. namespace {
  20. void OnModelOverrideProcessed(OnPredictionModelBuiltCallback callback,
  21. std::unique_ptr<proto::PredictionModel> model) {
  22. std::move(callback).Run(std::move(model));
  23. }
  24. std::unique_ptr<proto::PredictionModel> ProcessModelOverrideOnBGThread(
  25. proto::OptimizationTarget optimization_target,
  26. const base::FilePath& unzipped_dir_path) {
  27. // Unpack and verify model info file.
  28. base::FilePath model_info_path = unzipped_dir_path.Append(
  29. PredictionModelDownloadManager::ModelInfoFileName());
  30. std::string binary_model_info_pb;
  31. if (!base::ReadFileToString(model_info_path, &binary_model_info_pb)) {
  32. LOG(ERROR) << "Failed to read " << FilePathToString(model_info_path);
  33. return nullptr;
  34. }
  35. proto::ModelInfo model_info;
  36. if (!model_info.ParseFromString(binary_model_info_pb)) {
  37. LOG(ERROR) << "Failed to parse " << FilePathToString(model_info_path);
  38. return nullptr;
  39. }
  40. if (!model_info.has_version() || !model_info.has_optimization_target()) {
  41. LOG(ERROR) << FilePathToString(model_info_path)
  42. << "is invalid because it does not contain a version and/or "
  43. "optimization target";
  44. return nullptr;
  45. }
  46. for (int i = 0; i < model_info.additional_files_size(); i++) {
  47. proto::AdditionalModelFile* additional_file =
  48. model_info.mutable_additional_files(i);
  49. base::FilePath additional_file_basename =
  50. *StringToFilePath(additional_file->file_path());
  51. base::FilePath additional_file_absolute =
  52. unzipped_dir_path.Append(additional_file_basename);
  53. additional_file->set_file_path(FilePathToString(additional_file_absolute));
  54. }
  55. std::unique_ptr<proto::PredictionModel> model =
  56. std::make_unique<proto::PredictionModel>();
  57. *model->mutable_model_info() = model_info;
  58. model->mutable_model()->set_download_url(
  59. FilePathToString(unzipped_dir_path.Append(GetBaseFileNameForModels())));
  60. return model;
  61. }
  62. void OnModelOverrideUnzipped(proto::OptimizationTarget optimization_target,
  63. const base::FilePath& unzipped_dir_path,
  64. OnPredictionModelBuiltCallback callback,
  65. bool success) {
  66. if (!success) {
  67. LOG(ERROR) << FilePathToString(unzipped_dir_path) << "failed to unzip";
  68. std::move(callback).Run(nullptr);
  69. return;
  70. }
  71. base::ThreadPool::PostTaskAndReplyWithResult(
  72. FROM_HERE, {base::MayBlock(), base::TaskPriority::BEST_EFFORT},
  73. base::BindOnce(&ProcessModelOverrideOnBGThread, optimization_target,
  74. unzipped_dir_path),
  75. base::BindOnce(&OnModelOverrideProcessed, std::move(callback)));
  76. }
  77. void OnModelOverrideVerified(
  78. proto::OptimizationTarget optimization_target,
  79. const std::string& passed_crx_file_path,
  80. OnPredictionModelBuiltCallback callback,
  81. absl::optional<std::pair<base::FilePath, base::FilePath>> src_dst) {
  82. if (!src_dst) {
  83. LOG(ERROR) << passed_crx_file_path << " failed verification";
  84. std::move(callback).Run(nullptr);
  85. return;
  86. }
  87. #if BUILDFLAG(IS_IOS)
  88. auto unzipper = unzip::LaunchInProcessUnzipper();
  89. #else
  90. auto unzipper = unzip::LaunchUnzipper();
  91. #endif
  92. unzip::Unzip(std::move(unzipper), src_dst->first, src_dst->second,
  93. base::BindOnce(&OnModelOverrideUnzipped, optimization_target,
  94. src_dst->second, std::move(callback)));
  95. }
  96. } // namespace
  97. void BuildPredictionModelFromCommandLineForOptimizationTarget(
  98. proto::OptimizationTarget optimization_target,
  99. OnPredictionModelBuiltCallback callback) {
  100. absl::optional<std::pair<std::string, absl::optional<proto::Any>>>
  101. model_file_path_and_metadata =
  102. GetModelOverrideForOptimizationTarget(optimization_target);
  103. if (!model_file_path_and_metadata) {
  104. std::move(callback).Run(nullptr);
  105. return;
  106. }
  107. if (base::EndsWith(model_file_path_and_metadata->first, ".crx3")) {
  108. DVLOG(0) << "Attempting to parse the model override at "
  109. << model_file_path_and_metadata->first
  110. << " as a crx model package for "
  111. << GetStringNameForOptimizationTarget(optimization_target);
  112. if (model_file_path_and_metadata->second) {
  113. LOG(ERROR) << "Ignoring the metadata that was passed since a crx package "
  114. "was given";
  115. }
  116. base::ThreadPool::PostTaskAndReplyWithResult(
  117. FROM_HERE, {base::MayBlock(), base::TaskPriority::BEST_EFFORT},
  118. base::BindOnce(PredictionModelDownloadManager::VerifyDownload,
  119. *StringToFilePath(model_file_path_and_metadata->first),
  120. /*delete_file_on_error=*/false),
  121. base::BindOnce(&OnModelOverrideVerified, optimization_target,
  122. model_file_path_and_metadata->first,
  123. std::move(callback)));
  124. return;
  125. }
  126. std::unique_ptr<proto::PredictionModel> prediction_model =
  127. std::make_unique<proto::PredictionModel>();
  128. prediction_model->mutable_model_info()->set_optimization_target(
  129. optimization_target);
  130. prediction_model->mutable_model_info()->set_version(123);
  131. if (model_file_path_and_metadata->second) {
  132. *prediction_model->mutable_model_info()->mutable_model_metadata() =
  133. model_file_path_and_metadata->second.value();
  134. }
  135. prediction_model->mutable_model()->set_download_url(
  136. model_file_path_and_metadata->first);
  137. std::move(callback).Run(std::move(prediction_model));
  138. }
  139. } // namespace optimization_guide