ranker_model_loader_impl.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/ranker_model_loader_impl.h"
  5. #include <utility>
  6. #include <memory>
  7. #include "base/bind.h"
  8. #include "base/callback_helpers.h"
  9. #include "base/command_line.h"
  10. #include "base/files/file_util.h"
  11. #include "base/files/important_file_writer.h"
  12. #include "base/metrics/histogram_macros.h"
  13. #include "base/strings/string_util.h"
  14. #include "base/task/sequenced_task_runner.h"
  15. #include "base/task/task_runner_util.h"
  16. #include "base/task/thread_pool.h"
  17. #include "base/threading/sequenced_task_runner_handle.h"
  18. #include "components/assist_ranker/proto/ranker_model.pb.h"
  19. #include "components/assist_ranker/ranker_url_fetcher.h"
  20. #include "services/network/public/cpp/shared_url_loader_factory.h"
  21. namespace assist_ranker {
  22. namespace {
  23. // The minimum duration, in minutes, between download attempts.
  24. constexpr int kMinRetryDelayMins = 3;
  25. // Suffixes for the various histograms produced by the backend.
  26. const char kWriteTimerHistogram[] = ".Timer.WriteModel";
  27. const char kReadTimerHistogram[] = ".Timer.ReadModel";
  28. const char kDownloadTimerHistogram[] = ".Timer.DownloadModel";
  29. const char kParsetimerHistogram[] = ".Timer.ParseModel";
  30. const char kModelStatusHistogram[] = ".Model.Status";
  31. // Helper function to UMA log a timer histograms.
  32. void RecordTimerHistogram(const std::string& name, base::TimeDelta duration) {
  33. base::HistogramBase* counter = base::Histogram::FactoryTimeGet(
  34. name, base::Milliseconds(10), base::Milliseconds(200000), 100,
  35. base::HistogramBase::kUmaTargetedHistogramFlag);
  36. DCHECK(counter);
  37. counter->AddTime(duration);
  38. }
  39. // A helper class to produce a scoped timer histogram that supports using a
  40. // non-static-const name.
  41. class MyScopedHistogramTimer {
  42. public:
  43. MyScopedHistogramTimer(const base::StringPiece& name)
  44. : name_(name.begin(), name.end()), start_(base::TimeTicks::Now()) {}
  45. MyScopedHistogramTimer(const MyScopedHistogramTimer&) = delete;
  46. MyScopedHistogramTimer& operator=(const MyScopedHistogramTimer&) = delete;
  47. ~MyScopedHistogramTimer() {
  48. RecordTimerHistogram(name_, base::TimeTicks::Now() - start_);
  49. }
  50. private:
  51. const std::string name_;
  52. const base::TimeTicks start_;
  53. };
  54. std::string LoadFromFile(const base::FilePath& model_path) {
  55. DCHECK(!model_path.empty());
  56. DVLOG(2) << "Reading data from: " << model_path.value();
  57. std::string data;
  58. if (!base::ReadFileToString(model_path, &data) || data.empty()) {
  59. DVLOG(2) << "Failed to read data from: " << model_path.value();
  60. data.clear();
  61. }
  62. return data;
  63. }
  64. void SaveToFile(const GURL& model_url,
  65. const base::FilePath& model_path,
  66. const std::string& model_data,
  67. const std::string& uma_prefix) {
  68. DVLOG(2) << "Saving model from '" << model_url << "'' to '"
  69. << model_path.value() << "'.";
  70. MyScopedHistogramTimer timer(uma_prefix + kWriteTimerHistogram);
  71. base::ImportantFileWriter::WriteFileAtomically(model_path, model_data);
  72. }
  73. } // namespace
  74. RankerModelLoaderImpl::RankerModelLoaderImpl(
  75. ValidateModelCallback validate_model_cb,
  76. OnModelAvailableCallback on_model_available_cb,
  77. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  78. base::FilePath model_path,
  79. GURL model_url,
  80. std::string uma_prefix)
  81. : background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
  82. {base::MayBlock(), base::TaskPriority::BEST_EFFORT,
  83. base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})),
  84. validate_model_cb_(std::move(validate_model_cb)),
  85. on_model_available_cb_(std::move(on_model_available_cb)),
  86. url_loader_factory_(std::move(url_loader_factory)),
  87. model_path_(std::move(model_path)),
  88. model_url_(std::move(model_url)),
  89. uma_prefix_(std::move(uma_prefix)),
  90. url_fetcher_(std::make_unique<RankerURLFetcher>()) {}
  91. RankerModelLoaderImpl::~RankerModelLoaderImpl() {
  92. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  93. }
  94. void RankerModelLoaderImpl::NotifyOfRankerActivity() {
  95. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  96. switch (state_) {
  97. case LoaderState::NOT_STARTED:
  98. if (!model_path_.empty()) {
  99. StartLoadFromFile();
  100. break;
  101. }
  102. // There was no configured model path. Switch the state to IDLE and
  103. // fall through to consider the URL.
  104. state_ = LoaderState::IDLE;
  105. [[fallthrough]];
  106. case LoaderState::IDLE:
  107. if (model_url_.is_valid()) {
  108. StartLoadFromURL();
  109. break;
  110. }
  111. // There was no configured model URL. Switch the state to FINISHED and
  112. // fall through.
  113. state_ = LoaderState::FINISHED;
  114. [[fallthrough]];
  115. case LoaderState::FINISHED:
  116. case LoaderState::LOADING_FROM_FILE:
  117. case LoaderState::LOADING_FROM_URL:
  118. // Nothing to do.
  119. break;
  120. }
  121. }
  122. void RankerModelLoaderImpl::StartLoadFromFile() {
  123. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  124. DCHECK_EQ(state_, LoaderState::NOT_STARTED);
  125. DCHECK(!model_path_.empty());
  126. state_ = LoaderState::LOADING_FROM_FILE;
  127. load_start_time_ = base::TimeTicks::Now();
  128. base::PostTaskAndReplyWithResult(
  129. background_task_runner_.get(), FROM_HERE,
  130. base::BindOnce(&LoadFromFile, model_path_),
  131. base::BindOnce(&RankerModelLoaderImpl::OnFileLoaded,
  132. weak_ptr_factory_.GetWeakPtr()));
  133. }
  134. void RankerModelLoaderImpl::OnFileLoaded(const std::string& data) {
  135. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  136. DCHECK_EQ(state_, LoaderState::LOADING_FROM_FILE);
  137. // Record the duration of the download.
  138. RecordTimerHistogram(uma_prefix_ + kReadTimerHistogram,
  139. base::TimeTicks::Now() - load_start_time_);
  140. // Empty data means |model_path| wasn't successfully read. Otherwise,
  141. // parse and validate the model.
  142. std::unique_ptr<RankerModel> model;
  143. if (data.empty()) {
  144. ReportModelStatus(RankerModelStatus::LOAD_FROM_CACHE_FAILED);
  145. } else {
  146. model = CreateAndValidateModel(data);
  147. }
  148. // If |model| is nullptr, then data is empty or the parse failed. Transition
  149. // to IDLE, from which URL download can be attempted.
  150. if (!model) {
  151. state_ = LoaderState::IDLE;
  152. } else {
  153. // The model is valid. The client is willing/able to use it. Keep track
  154. // of where it originated and whether or not is has expired.
  155. std::string url_spec = model->GetSourceURL();
  156. bool is_expired = model->IsExpired();
  157. bool is_finished = url_spec == model_url_.spec() && !is_expired;
  158. DVLOG(2) << (is_expired ? "Expired m" : "M") << "odel in '"
  159. << model_path_.value() << "' was originally downloaded from '"
  160. << url_spec << "'.";
  161. // If the cached model came from currently configured |model_url_| and has
  162. // not expired, transition to FINISHED, as there is no need for a model
  163. // download; otherwise, transition to IDLE.
  164. state_ = is_finished ? LoaderState::FINISHED : LoaderState::IDLE;
  165. // Transfer the model to the client.
  166. on_model_available_cb_.Run(std::move(model));
  167. }
  168. // Notify the state machine. This will immediately kick off a download if
  169. // one is required, instead of waiting for the next organic detection of
  170. // ranker activity.
  171. NotifyOfRankerActivity();
  172. }
  173. void RankerModelLoaderImpl::StartLoadFromURL() {
  174. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  175. DCHECK_EQ(state_, LoaderState::IDLE);
  176. DCHECK(model_url_.is_valid());
  177. // Do nothing if the download attempts should be throttled.
  178. if (base::TimeTicks::Now() < next_earliest_download_time_) {
  179. DVLOG(2) << "Last download attempt was too recent.";
  180. ReportModelStatus(RankerModelStatus::DOWNLOAD_THROTTLED);
  181. return;
  182. }
  183. // Kick off the next download attempt and reset the time of the next earliest
  184. // allowable download attempt.
  185. DVLOG(2) << "Downloading model from: " << model_url_;
  186. state_ = LoaderState::LOADING_FROM_URL;
  187. load_start_time_ = base::TimeTicks::Now();
  188. next_earliest_download_time_ =
  189. load_start_time_ + base::Minutes(kMinRetryDelayMins);
  190. bool request_started =
  191. url_fetcher_->Request(model_url_,
  192. base::BindOnce(&RankerModelLoaderImpl::OnURLFetched,
  193. weak_ptr_factory_.GetWeakPtr()),
  194. url_loader_factory_.get());
  195. // |url_fetcher_| maintains a request retry counter. If all allowed attempts
  196. // have already been exhausted, then the loader is finished and has abandoned
  197. // loading the model.
  198. if (!request_started) {
  199. DVLOG(2) << "Model download abandoned.";
  200. ReportModelStatus(RankerModelStatus::MODEL_LOADING_ABANDONED);
  201. state_ = LoaderState::FINISHED;
  202. }
  203. }
  204. void RankerModelLoaderImpl::OnURLFetched(bool success,
  205. const std::string& data) {
  206. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  207. DCHECK_EQ(state_, LoaderState::LOADING_FROM_URL);
  208. // Record the duration of the download.
  209. RecordTimerHistogram(uma_prefix_ + kDownloadTimerHistogram,
  210. base::TimeTicks::Now() - load_start_time_);
  211. // On request failure, transition back to IDLE. The loader will retry, or
  212. // enforce the max download attempts, later.
  213. if (!success || data.empty()) {
  214. DVLOG(2) << "Download from '" << model_url_ << "'' failed.";
  215. ReportModelStatus(RankerModelStatus::DOWNLOAD_FAILED);
  216. state_ = LoaderState::IDLE;
  217. return;
  218. }
  219. // Attempt to loads the model. If this fails, transition back to IDLE. The
  220. // loader will retry, or enfore the max download attempts, later.
  221. auto model = CreateAndValidateModel(data);
  222. if (!model) {
  223. DVLOG(2) << "Model from '" << model_url_ << "'' not valid.";
  224. state_ = LoaderState::IDLE;
  225. return;
  226. }
  227. // The model is valid. Update the metadata to track the source URL and
  228. // download timestamp.
  229. auto* metadata = model->mutable_proto()->mutable_metadata();
  230. metadata->set_source(model_url_.spec());
  231. metadata->set_last_modified_sec(
  232. (base::Time::Now() - base::Time()).InSeconds());
  233. // Cache the model to model_path_, in the background.
  234. if (!model_path_.empty()) {
  235. background_task_runner_->PostTask(
  236. FROM_HERE, base::BindOnce(&SaveToFile, model_url_, model_path_,
  237. model->SerializeAsString(), uma_prefix_));
  238. }
  239. // The loader is finished. Transfer the model to the client.
  240. state_ = LoaderState::FINISHED;
  241. on_model_available_cb_.Run(std::move(model));
  242. }
  243. std::unique_ptr<RankerModel> RankerModelLoaderImpl::CreateAndValidateModel(
  244. const std::string& data) {
  245. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  246. MyScopedHistogramTimer timer(uma_prefix_ + kParsetimerHistogram);
  247. auto model = RankerModel::FromString(data);
  248. if (ReportModelStatus(model ? validate_model_cb_.Run(*model)
  249. : RankerModelStatus::PARSE_FAILED) !=
  250. RankerModelStatus::OK) {
  251. return nullptr;
  252. }
  253. return model;
  254. }
  255. RankerModelStatus RankerModelLoaderImpl::ReportModelStatus(
  256. RankerModelStatus model_status) {
  257. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  258. base::HistogramBase* histogram = base::LinearHistogram::FactoryGet(
  259. uma_prefix_ + kModelStatusHistogram, 1,
  260. static_cast<int>(RankerModelStatus::MAX),
  261. static_cast<int>(RankerModelStatus::MAX) + 1,
  262. base::HistogramBase::kUmaTargetedHistogramFlag);
  263. if (histogram)
  264. histogram->Add(static_cast<int>(model_status));
  265. return model_status;
  266. }
  267. } // namespace assist_ranker