// Copyright 2017 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "components/assist_ranker/ranker_model_loader_impl.h" #include #include #include "base/bind.h" #include "base/callback_helpers.h" #include "base/command_line.h" #include "base/files/file_util.h" #include "base/files/important_file_writer.h" #include "base/metrics/histogram_macros.h" #include "base/strings/string_util.h" #include "base/task/sequenced_task_runner.h" #include "base/task/task_runner_util.h" #include "base/task/thread_pool.h" #include "base/threading/sequenced_task_runner_handle.h" #include "components/assist_ranker/proto/ranker_model.pb.h" #include "components/assist_ranker/ranker_url_fetcher.h" #include "services/network/public/cpp/shared_url_loader_factory.h" namespace assist_ranker { namespace { // The minimum duration, in minutes, between download attempts. constexpr int kMinRetryDelayMins = 3; // Suffixes for the various histograms produced by the backend. const char kWriteTimerHistogram[] = ".Timer.WriteModel"; const char kReadTimerHistogram[] = ".Timer.ReadModel"; const char kDownloadTimerHistogram[] = ".Timer.DownloadModel"; const char kParsetimerHistogram[] = ".Timer.ParseModel"; const char kModelStatusHistogram[] = ".Model.Status"; // Helper function to UMA log a timer histograms. void RecordTimerHistogram(const std::string& name, base::TimeDelta duration) { base::HistogramBase* counter = base::Histogram::FactoryTimeGet( name, base::Milliseconds(10), base::Milliseconds(200000), 100, base::HistogramBase::kUmaTargetedHistogramFlag); DCHECK(counter); counter->AddTime(duration); } // A helper class to produce a scoped timer histogram that supports using a // non-static-const name. class MyScopedHistogramTimer { public: MyScopedHistogramTimer(const base::StringPiece& name) : name_(name.begin(), name.end()), start_(base::TimeTicks::Now()) {} MyScopedHistogramTimer(const MyScopedHistogramTimer&) = delete; MyScopedHistogramTimer& operator=(const MyScopedHistogramTimer&) = delete; ~MyScopedHistogramTimer() { RecordTimerHistogram(name_, base::TimeTicks::Now() - start_); } private: const std::string name_; const base::TimeTicks start_; }; std::string LoadFromFile(const base::FilePath& model_path) { DCHECK(!model_path.empty()); DVLOG(2) << "Reading data from: " << model_path.value(); std::string data; if (!base::ReadFileToString(model_path, &data) || data.empty()) { DVLOG(2) << "Failed to read data from: " << model_path.value(); data.clear(); } return data; } void SaveToFile(const GURL& model_url, const base::FilePath& model_path, const std::string& model_data, const std::string& uma_prefix) { DVLOG(2) << "Saving model from '" << model_url << "'' to '" << model_path.value() << "'."; MyScopedHistogramTimer timer(uma_prefix + kWriteTimerHistogram); base::ImportantFileWriter::WriteFileAtomically(model_path, model_data); } } // namespace RankerModelLoaderImpl::RankerModelLoaderImpl( ValidateModelCallback validate_model_cb, OnModelAvailableCallback on_model_available_cb, scoped_refptr url_loader_factory, base::FilePath model_path, GURL model_url, std::string uma_prefix) : background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner( {base::MayBlock(), base::TaskPriority::BEST_EFFORT, base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})), validate_model_cb_(std::move(validate_model_cb)), on_model_available_cb_(std::move(on_model_available_cb)), url_loader_factory_(std::move(url_loader_factory)), model_path_(std::move(model_path)), model_url_(std::move(model_url)), uma_prefix_(std::move(uma_prefix)), url_fetcher_(std::make_unique()) {} RankerModelLoaderImpl::~RankerModelLoaderImpl() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); } void RankerModelLoaderImpl::NotifyOfRankerActivity() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); switch (state_) { case LoaderState::NOT_STARTED: if (!model_path_.empty()) { StartLoadFromFile(); break; } // There was no configured model path. Switch the state to IDLE and // fall through to consider the URL. state_ = LoaderState::IDLE; [[fallthrough]]; case LoaderState::IDLE: if (model_url_.is_valid()) { StartLoadFromURL(); break; } // There was no configured model URL. Switch the state to FINISHED and // fall through. state_ = LoaderState::FINISHED; [[fallthrough]]; case LoaderState::FINISHED: case LoaderState::LOADING_FROM_FILE: case LoaderState::LOADING_FROM_URL: // Nothing to do. break; } } void RankerModelLoaderImpl::StartLoadFromFile() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_EQ(state_, LoaderState::NOT_STARTED); DCHECK(!model_path_.empty()); state_ = LoaderState::LOADING_FROM_FILE; load_start_time_ = base::TimeTicks::Now(); base::PostTaskAndReplyWithResult( background_task_runner_.get(), FROM_HERE, base::BindOnce(&LoadFromFile, model_path_), base::BindOnce(&RankerModelLoaderImpl::OnFileLoaded, weak_ptr_factory_.GetWeakPtr())); } void RankerModelLoaderImpl::OnFileLoaded(const std::string& data) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_EQ(state_, LoaderState::LOADING_FROM_FILE); // Record the duration of the download. RecordTimerHistogram(uma_prefix_ + kReadTimerHistogram, base::TimeTicks::Now() - load_start_time_); // Empty data means |model_path| wasn't successfully read. Otherwise, // parse and validate the model. std::unique_ptr model; if (data.empty()) { ReportModelStatus(RankerModelStatus::LOAD_FROM_CACHE_FAILED); } else { model = CreateAndValidateModel(data); } // If |model| is nullptr, then data is empty or the parse failed. Transition // to IDLE, from which URL download can be attempted. if (!model) { state_ = LoaderState::IDLE; } else { // The model is valid. The client is willing/able to use it. Keep track // of where it originated and whether or not is has expired. std::string url_spec = model->GetSourceURL(); bool is_expired = model->IsExpired(); bool is_finished = url_spec == model_url_.spec() && !is_expired; DVLOG(2) << (is_expired ? "Expired m" : "M") << "odel in '" << model_path_.value() << "' was originally downloaded from '" << url_spec << "'."; // If the cached model came from currently configured |model_url_| and has // not expired, transition to FINISHED, as there is no need for a model // download; otherwise, transition to IDLE. state_ = is_finished ? LoaderState::FINISHED : LoaderState::IDLE; // Transfer the model to the client. on_model_available_cb_.Run(std::move(model)); } // Notify the state machine. This will immediately kick off a download if // one is required, instead of waiting for the next organic detection of // ranker activity. NotifyOfRankerActivity(); } void RankerModelLoaderImpl::StartLoadFromURL() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_EQ(state_, LoaderState::IDLE); DCHECK(model_url_.is_valid()); // Do nothing if the download attempts should be throttled. if (base::TimeTicks::Now() < next_earliest_download_time_) { DVLOG(2) << "Last download attempt was too recent."; ReportModelStatus(RankerModelStatus::DOWNLOAD_THROTTLED); return; } // Kick off the next download attempt and reset the time of the next earliest // allowable download attempt. DVLOG(2) << "Downloading model from: " << model_url_; state_ = LoaderState::LOADING_FROM_URL; load_start_time_ = base::TimeTicks::Now(); next_earliest_download_time_ = load_start_time_ + base::Minutes(kMinRetryDelayMins); bool request_started = url_fetcher_->Request(model_url_, base::BindOnce(&RankerModelLoaderImpl::OnURLFetched, weak_ptr_factory_.GetWeakPtr()), url_loader_factory_.get()); // |url_fetcher_| maintains a request retry counter. If all allowed attempts // have already been exhausted, then the loader is finished and has abandoned // loading the model. if (!request_started) { DVLOG(2) << "Model download abandoned."; ReportModelStatus(RankerModelStatus::MODEL_LOADING_ABANDONED); state_ = LoaderState::FINISHED; } } void RankerModelLoaderImpl::OnURLFetched(bool success, const std::string& data) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_EQ(state_, LoaderState::LOADING_FROM_URL); // Record the duration of the download. RecordTimerHistogram(uma_prefix_ + kDownloadTimerHistogram, base::TimeTicks::Now() - load_start_time_); // On request failure, transition back to IDLE. The loader will retry, or // enforce the max download attempts, later. if (!success || data.empty()) { DVLOG(2) << "Download from '" << model_url_ << "'' failed."; ReportModelStatus(RankerModelStatus::DOWNLOAD_FAILED); state_ = LoaderState::IDLE; return; } // Attempt to loads the model. If this fails, transition back to IDLE. The // loader will retry, or enfore the max download attempts, later. auto model = CreateAndValidateModel(data); if (!model) { DVLOG(2) << "Model from '" << model_url_ << "'' not valid."; state_ = LoaderState::IDLE; return; } // The model is valid. Update the metadata to track the source URL and // download timestamp. auto* metadata = model->mutable_proto()->mutable_metadata(); metadata->set_source(model_url_.spec()); metadata->set_last_modified_sec( (base::Time::Now() - base::Time()).InSeconds()); // Cache the model to model_path_, in the background. if (!model_path_.empty()) { background_task_runner_->PostTask( FROM_HERE, base::BindOnce(&SaveToFile, model_url_, model_path_, model->SerializeAsString(), uma_prefix_)); } // The loader is finished. Transfer the model to the client. state_ = LoaderState::FINISHED; on_model_available_cb_.Run(std::move(model)); } std::unique_ptr RankerModelLoaderImpl::CreateAndValidateModel( const std::string& data) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); MyScopedHistogramTimer timer(uma_prefix_ + kParsetimerHistogram); auto model = RankerModel::FromString(data); if (ReportModelStatus(model ? validate_model_cb_.Run(*model) : RankerModelStatus::PARSE_FAILED) != RankerModelStatus::OK) { return nullptr; } return model; } RankerModelStatus RankerModelLoaderImpl::ReportModelStatus( RankerModelStatus model_status) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); base::HistogramBase* histogram = base::LinearHistogram::FactoryGet( uma_prefix_ + kModelStatusHistogram, 1, static_cast(RankerModelStatus::MAX), static_cast(RankerModelStatus::MAX) + 1, base::HistogramBase::kUmaTargetedHistogramFlag); if (histogram) histogram->Add(static_cast(model_status)); return model_status; } } // namespace assist_ranker