123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- // Copyright 2017 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/assist_ranker/ranker_model_loader_impl.h"
- #include <utility>
- #include <memory>
- #include "base/bind.h"
- #include "base/callback_helpers.h"
- #include "base/command_line.h"
- #include "base/files/file_util.h"
- #include "base/files/important_file_writer.h"
- #include "base/metrics/histogram_macros.h"
- #include "base/strings/string_util.h"
- #include "base/task/sequenced_task_runner.h"
- #include "base/task/task_runner_util.h"
- #include "base/task/thread_pool.h"
- #include "base/threading/sequenced_task_runner_handle.h"
- #include "components/assist_ranker/proto/ranker_model.pb.h"
- #include "components/assist_ranker/ranker_url_fetcher.h"
- #include "services/network/public/cpp/shared_url_loader_factory.h"
- namespace assist_ranker {
- namespace {
// Minimum time, in minutes, that must elapse between two consecutive model
// download attempts.
constexpr int kMinRetryDelayMins = 3;

// Histogram name suffixes. Each one is appended to the loader's UMA prefix to
// form the full histogram name.
constexpr char kWriteTimerHistogram[] = ".Timer.WriteModel";
constexpr char kReadTimerHistogram[] = ".Timer.ReadModel";
constexpr char kDownloadTimerHistogram[] = ".Timer.DownloadModel";
constexpr char kParsetimerHistogram[] = ".Timer.ParseModel";
constexpr char kModelStatusHistogram[] = ".Model.Status";
- // Helper function to UMA log a timer histograms.
- void RecordTimerHistogram(const std::string& name, base::TimeDelta duration) {
- base::HistogramBase* counter = base::Histogram::FactoryTimeGet(
- name, base::Milliseconds(10), base::Milliseconds(200000), 100,
- base::HistogramBase::kUmaTargetedHistogramFlag);
- DCHECK(counter);
- counter->AddTime(duration);
- }
- // A helper class to produce a scoped timer histogram that supports using a
- // non-static-const name.
- class MyScopedHistogramTimer {
- public:
- MyScopedHistogramTimer(const base::StringPiece& name)
- : name_(name.begin(), name.end()), start_(base::TimeTicks::Now()) {}
- MyScopedHistogramTimer(const MyScopedHistogramTimer&) = delete;
- MyScopedHistogramTimer& operator=(const MyScopedHistogramTimer&) = delete;
- ~MyScopedHistogramTimer() {
- RecordTimerHistogram(name_, base::TimeTicks::Now() - start_);
- }
- private:
- const std::string name_;
- const base::TimeTicks start_;
- };
- std::string LoadFromFile(const base::FilePath& model_path) {
- DCHECK(!model_path.empty());
- DVLOG(2) << "Reading data from: " << model_path.value();
- std::string data;
- if (!base::ReadFileToString(model_path, &data) || data.empty()) {
- DVLOG(2) << "Failed to read data from: " << model_path.value();
- data.clear();
- }
- return data;
- }
- void SaveToFile(const GURL& model_url,
- const base::FilePath& model_path,
- const std::string& model_data,
- const std::string& uma_prefix) {
- DVLOG(2) << "Saving model from '" << model_url << "'' to '"
- << model_path.value() << "'.";
- MyScopedHistogramTimer timer(uma_prefix + kWriteTimerHistogram);
- base::ImportantFileWriter::WriteFileAtomically(model_path, model_data);
- }
- } // namespace
// Creates a loader that first tries to read a cached model from |model_path|
// and then, if needed, downloads one from |model_url|. Loading does not start
// here. It is driven by NotifyOfRankerActivity().
//   validate_model_cb: decides whether a parsed model is usable.
//   on_model_available_cb: receives each valid model (from cache or network).
//   url_loader_factory: used by |url_fetcher_| for downloads.
//   uma_prefix: prepended to every histogram name this loader records.
RankerModelLoaderImpl::RankerModelLoaderImpl(
    ValidateModelCallback validate_model_cb,
    OnModelAvailableCallback on_model_available_cb,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
    base::FilePath model_path,
    GURL model_url,
    std::string uma_prefix)
    // File reads/writes run on a low-priority sequenced runner that may block
    // and whose pending tasks are skipped at shutdown.
    // NOTE(review): the initializer order below must match the member
    // declaration order in the header.
    : background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
          {base::MayBlock(), base::TaskPriority::BEST_EFFORT,
           base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})),
      validate_model_cb_(std::move(validate_model_cb)),
      on_model_available_cb_(std::move(on_model_available_cb)),
      url_loader_factory_(std::move(url_loader_factory)),
      model_path_(std::move(model_path)),
      model_url_(std::move(model_url)),
      uma_prefix_(std::move(uma_prefix)),
      url_fetcher_(std::make_unique<RankerURLFetcher>()) {}

// Must be destroyed on the same sequence the loader was used on.
RankerModelLoaderImpl::~RankerModelLoaderImpl() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
// Advances the loader's state machine in response to ranker activity.
//   NOT_STARTED: starts reading the cached model from |model_path_|. If no path
//     is configured, moves to IDLE and falls through.
//   IDLE: starts a download from |model_url_|. If no valid URL is configured,
//     moves to FINISHED and falls through.
//   LOADING_FROM_FILE / LOADING_FROM_URL / FINISHED: does nothing.
// The case order matters: each [[fallthrough]] relies on the next case
// handling the state just assigned.
void RankerModelLoaderImpl::NotifyOfRankerActivity() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  switch (state_) {
    case LoaderState::NOT_STARTED:
      if (!model_path_.empty()) {
        StartLoadFromFile();
        break;
      }
      // There was no configured model path. Switch the state to IDLE and
      // fall through to consider the URL.
      state_ = LoaderState::IDLE;
      [[fallthrough]];
    case LoaderState::IDLE:
      if (model_url_.is_valid()) {
        StartLoadFromURL();
        break;
      }
      // There was no configured model URL. Switch the state to FINISHED and
      // fall through.
      state_ = LoaderState::FINISHED;
      [[fallthrough]];
    case LoaderState::FINISHED:
    case LoaderState::LOADING_FROM_FILE:
    case LoaderState::LOADING_FROM_URL:
      // Nothing to do.
      break;
  }
}
- void RankerModelLoaderImpl::StartLoadFromFile() {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK_EQ(state_, LoaderState::NOT_STARTED);
- DCHECK(!model_path_.empty());
- state_ = LoaderState::LOADING_FROM_FILE;
- load_start_time_ = base::TimeTicks::Now();
- base::PostTaskAndReplyWithResult(
- background_task_runner_.get(), FROM_HERE,
- base::BindOnce(&LoadFromFile, model_path_),
- base::BindOnce(&RankerModelLoaderImpl::OnFileLoaded,
- weak_ptr_factory_.GetWeakPtr()));
- }
- void RankerModelLoaderImpl::OnFileLoaded(const std::string& data) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK_EQ(state_, LoaderState::LOADING_FROM_FILE);
- // Record the duration of the download.
- RecordTimerHistogram(uma_prefix_ + kReadTimerHistogram,
- base::TimeTicks::Now() - load_start_time_);
- // Empty data means |model_path| wasn't successfully read. Otherwise,
- // parse and validate the model.
- std::unique_ptr<RankerModel> model;
- if (data.empty()) {
- ReportModelStatus(RankerModelStatus::LOAD_FROM_CACHE_FAILED);
- } else {
- model = CreateAndValidateModel(data);
- }
- // If |model| is nullptr, then data is empty or the parse failed. Transition
- // to IDLE, from which URL download can be attempted.
- if (!model) {
- state_ = LoaderState::IDLE;
- } else {
- // The model is valid. The client is willing/able to use it. Keep track
- // of where it originated and whether or not is has expired.
- std::string url_spec = model->GetSourceURL();
- bool is_expired = model->IsExpired();
- bool is_finished = url_spec == model_url_.spec() && !is_expired;
- DVLOG(2) << (is_expired ? "Expired m" : "M") << "odel in '"
- << model_path_.value() << "' was originally downloaded from '"
- << url_spec << "'.";
- // If the cached model came from currently configured |model_url_| and has
- // not expired, transition to FINISHED, as there is no need for a model
- // download; otherwise, transition to IDLE.
- state_ = is_finished ? LoaderState::FINISHED : LoaderState::IDLE;
- // Transfer the model to the client.
- on_model_available_cb_.Run(std::move(model));
- }
- // Notify the state machine. This will immediately kick off a download if
- // one is required, instead of waiting for the next organic detection of
- // ranker activity.
- NotifyOfRankerActivity();
- }
- void RankerModelLoaderImpl::StartLoadFromURL() {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK_EQ(state_, LoaderState::IDLE);
- DCHECK(model_url_.is_valid());
- // Do nothing if the download attempts should be throttled.
- if (base::TimeTicks::Now() < next_earliest_download_time_) {
- DVLOG(2) << "Last download attempt was too recent.";
- ReportModelStatus(RankerModelStatus::DOWNLOAD_THROTTLED);
- return;
- }
- // Kick off the next download attempt and reset the time of the next earliest
- // allowable download attempt.
- DVLOG(2) << "Downloading model from: " << model_url_;
- state_ = LoaderState::LOADING_FROM_URL;
- load_start_time_ = base::TimeTicks::Now();
- next_earliest_download_time_ =
- load_start_time_ + base::Minutes(kMinRetryDelayMins);
- bool request_started =
- url_fetcher_->Request(model_url_,
- base::BindOnce(&RankerModelLoaderImpl::OnURLFetched,
- weak_ptr_factory_.GetWeakPtr()),
- url_loader_factory_.get());
- // |url_fetcher_| maintains a request retry counter. If all allowed attempts
- // have already been exhausted, then the loader is finished and has abandoned
- // loading the model.
- if (!request_started) {
- DVLOG(2) << "Model download abandoned.";
- ReportModelStatus(RankerModelStatus::MODEL_LOADING_ABANDONED);
- state_ = LoaderState::FINISHED;
- }
- }
- void RankerModelLoaderImpl::OnURLFetched(bool success,
- const std::string& data) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK_EQ(state_, LoaderState::LOADING_FROM_URL);
- // Record the duration of the download.
- RecordTimerHistogram(uma_prefix_ + kDownloadTimerHistogram,
- base::TimeTicks::Now() - load_start_time_);
- // On request failure, transition back to IDLE. The loader will retry, or
- // enforce the max download attempts, later.
- if (!success || data.empty()) {
- DVLOG(2) << "Download from '" << model_url_ << "'' failed.";
- ReportModelStatus(RankerModelStatus::DOWNLOAD_FAILED);
- state_ = LoaderState::IDLE;
- return;
- }
- // Attempt to loads the model. If this fails, transition back to IDLE. The
- // loader will retry, or enfore the max download attempts, later.
- auto model = CreateAndValidateModel(data);
- if (!model) {
- DVLOG(2) << "Model from '" << model_url_ << "'' not valid.";
- state_ = LoaderState::IDLE;
- return;
- }
- // The model is valid. Update the metadata to track the source URL and
- // download timestamp.
- auto* metadata = model->mutable_proto()->mutable_metadata();
- metadata->set_source(model_url_.spec());
- metadata->set_last_modified_sec(
- (base::Time::Now() - base::Time()).InSeconds());
- // Cache the model to model_path_, in the background.
- if (!model_path_.empty()) {
- background_task_runner_->PostTask(
- FROM_HERE, base::BindOnce(&SaveToFile, model_url_, model_path_,
- model->SerializeAsString(), uma_prefix_));
- }
- // The loader is finished. Transfer the model to the client.
- state_ = LoaderState::FINISHED;
- on_model_available_cb_.Run(std::move(model));
- }
- std::unique_ptr<RankerModel> RankerModelLoaderImpl::CreateAndValidateModel(
- const std::string& data) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- MyScopedHistogramTimer timer(uma_prefix_ + kParsetimerHistogram);
- auto model = RankerModel::FromString(data);
- if (ReportModelStatus(model ? validate_model_cb_.Run(*model)
- : RankerModelStatus::PARSE_FAILED) !=
- RankerModelStatus::OK) {
- return nullptr;
- }
- return model;
- }
- RankerModelStatus RankerModelLoaderImpl::ReportModelStatus(
- RankerModelStatus model_status) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- base::HistogramBase* histogram = base::LinearHistogram::FactoryGet(
- uma_prefix_ + kModelStatusHistogram, 1,
- static_cast<int>(RankerModelStatus::MAX),
- static_cast<int>(RankerModelStatus::MAX) + 1,
- base::HistogramBase::kUmaTargetedHistogramFlag);
- if (histogram)
- histogram->Add(static_cast<int>(model_status));
- return model_status;
- }
- } // namespace assist_ranker
|