tflite_model_executor.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_
  5. #define COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_
  6. #include "base/bind.h"
  7. #include "base/callback_forward.h"
  8. #include "base/files/memory_mapped_file.h"
  9. #include "base/logging.h"
  10. #include "base/memory/weak_ptr.h"
  11. #include "base/metrics/histogram.h"
  12. #include "base/metrics/histogram_functions.h"
  13. #include "base/sequence_checker.h"
  14. #include "base/task/thread_pool.h"
  15. #include "base/threading/sequenced_task_runner_handle.h"
  16. #include "base/time/time.h"
  17. #include "base/timer/elapsed_timer.h"
  18. #include "base/trace_event/trace_event.h"
  19. #include "components/optimization_guide/core/execution_status.h"
  20. #include "components/optimization_guide/core/model_enums.h"
  21. #include "components/optimization_guide/core/model_execution_timeout_watchdog.h"
  22. #include "components/optimization_guide/core/model_executor.h"
  23. #include "components/optimization_guide/core/model_util.h"
  24. #include "components/optimization_guide/core/optimization_guide_features.h"
  25. #include "third_party/abseil-cpp/absl/types/optional.h"
  26. #include "third_party/tflite/src/tensorflow/lite/c/common.h"
  27. #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"
  28. namespace optimization_guide {
  29. namespace {
  30. // Util class for recording the result of the model execution. The result is
  31. // recorded when it goes out of scope and its destructor is called.
  32. class ScopedExecutionStatusResultRecorder {
  33. public:
  34. explicit ScopedExecutionStatusResultRecorder(
  35. proto::OptimizationTarget optimization_target)
  36. : optimization_target_(optimization_target) {}
  37. ~ScopedExecutionStatusResultRecorder() {
  38. base::UmaHistogramEnumeration(
  39. "OptimizationGuide.ModelExecutor.ExecutionStatus." +
  40. optimization_guide::GetStringNameForOptimizationTarget(
  41. optimization_target_),
  42. status_);
  43. }
  44. ExecutionStatus* mutable_status() { return &status_; }
  45. ExecutionStatus status() const { return status_; }
  46. void set_status(ExecutionStatus status) { status_ = status; }
  47. private:
  48. // The OptimizationTarget of the model being executed.
  49. const proto::OptimizationTarget optimization_target_;
  50. ExecutionStatus status_ = ExecutionStatus::kUnknown;
  51. };
  52. } // namespace
  53. // An ModelExecutor that executes tflite models with arbitrary
  54. // input and output types. Note that callers will need to give an implementation
  55. // of this class to a |ModelHandler|, whereas the
  56. // handle is the actual class that calling code would own and call into.
  57. //
  58. // By default, the model file will be (re)loaded for every execution and then
  59. // unloaded from memory after every execution (e.g.: "OnComplete"). This helps
  60. // to keep memory usage of the browser process down, but does delay model
  61. // execution by the time it takes to load the model (about 50ms in practice).
  62. // See |SetShouldUnloadModelOnComplete| to override this behavior.
  63. template <class OutputType, class... InputTypes>
  64. class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
  65. public:
  66. TFLiteModelExecutor()
  67. : watchdog_(nullptr, base::OnTaskRunnerDeleter(nullptr)) {}
  68. ~TFLiteModelExecutor() override {
  69. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  70. }
  71. // Should be called on the same sequence as the ctor, but once called |this|
  72. // must only be used from the |execution_task_runner| thread/sequence.
  73. void InitializeAndMoveToExecutionThread(
  74. absl::optional<base::TimeDelta> model_inference_timeout,
  75. proto::OptimizationTarget optimization_target,
  76. scoped_refptr<base::SequencedTaskRunner> execution_task_runner,
  77. scoped_refptr<base::SequencedTaskRunner> reply_task_runner) override {
  78. DCHECK(!execution_task_runner_);
  79. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  80. DCHECK_NE(optimization_target,
  81. proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN);
  82. DETACH_FROM_SEQUENCE(sequence_checker_);
  83. optimization_target_ = optimization_target;
  84. execution_task_runner_ = execution_task_runner;
  85. reply_task_runner_ = reply_task_runner;
  86. if (features::IsModelExecutionWatchdogEnabled()) {
  87. // The sequence |watchdog_sequence| is used to run watchdog's task. The
  88. // watchdog must be deleted on that sequence to guarantee that pending
  89. // tasks can safely be executed.
  90. scoped_refptr<base::SequencedTaskRunner> watchdog_sequence =
  91. base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()});
  92. using WatchdogType =
  93. ModelExecutionTimeoutWatchdog<OutputType, InputTypes...>;
  94. watchdog_ = std::unique_ptr<WatchdogType, base::OnTaskRunnerDeleter>(
  95. new WatchdogType(
  96. watchdog_sequence, optimization_target_,
  97. model_inference_timeout.value_or(
  98. features::ModelExecutionWatchdogDefaultTimeout())),
  99. base::OnTaskRunnerDeleter(watchdog_sequence));
  100. }
  101. }
  102. // Called when a model file is available to load. Depending on feature flags,
  103. // the model may or may not be immediately loaded.
  104. void UpdateModelFile(const base::FilePath& file_path) override {
  105. DCHECK(execution_task_runner_ &&
  106. execution_task_runner_->RunsTasksInCurrentSequence());
  107. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  108. UnloadModel();
  109. model_file_path_ = file_path;
  110. // crbug/1257189: Histogram enums can't use dynamically created histogram
  111. // names, so factory create the local histogram (used in testing).
  112. base::HistogramBase* histogram = base::BooleanHistogram::FactoryGet(
  113. "OptimizationGuide.ModelExecutor.ModelFileUpdated." +
  114. optimization_guide::GetStringNameForOptimizationTarget(
  115. optimization_target_),
  116. base::Histogram::kNoFlags);
  117. histogram->Add(true);
  118. }
  119. // Calling this method allows the default model loading/unloading behavior to
  120. // be overridden. Setting this to true will cause the model to remain loaded
  121. // afterwards a model execution (e.g.: "OnComplete"), until |UnloadModel| is
  122. // called. False is the default behavior (see class comment).
  123. void SetShouldUnloadModelOnComplete(
  124. bool should_unload_model_on_complete) override {
  125. DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
  126. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  127. should_unload_model_on_complete_ = should_unload_model_on_complete;
  128. }
  129. // Clears the loaded model from memory if it is loaded. Safe to call when the
  130. // model is already unloaded, and becomes a no-op.
  131. void UnloadModel() override {
  132. TRACE_EVENT1("browser", "OptGuideModelExecutor::UnloadModel",
  133. "OptimizationTarget",
  134. optimization_guide::GetStringNameForOptimizationTarget(
  135. optimization_target_));
  136. DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
  137. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  138. loaded_model_.reset();
  139. model_fb_.reset();
  140. }
  141. // Starts the execution of the model. When complete, |callback_on_complete|
  142. // will be run via |reply_task_runner_| with the output of the model.
  143. using ExecutionCallback =
  144. base::OnceCallback<void(const absl::optional<OutputType>&)>;
  145. void SendForExecution(ExecutionCallback callback_on_complete,
  146. base::TimeTicks start_time,
  147. InputTypes... args) override {
  148. DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
  149. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  150. DCHECK(reply_task_runner_);
  151. base::TimeDelta task_scheduling_latency =
  152. base::TimeTicks::Now() - start_time;
  153. base::UmaHistogramMediumTimes(
  154. "OptimizationGuide.ModelExecutor.TaskSchedulingLatency." +
  155. optimization_guide::GetStringNameForOptimizationTarget(
  156. optimization_target_),
  157. task_scheduling_latency);
  158. ScopedExecutionStatusResultRecorder status_recorder(optimization_target_);
  159. // Attempt to load the model file if it isn't loaded yet, fail if loading is
  160. // unsuccessful or no model is available to load.
  161. if (!loaded_model_ && !LoadModelFile(status_recorder.mutable_status())) {
  162. reply_task_runner_->PostTask(
  163. FROM_HERE,
  164. base::BindOnce(std::move(callback_on_complete), absl::nullopt));
  165. // Some error status is expected, and derived classes should have set the
  166. // status.
  167. DCHECK_NE(status_recorder.status(), ExecutionStatus::kUnknown);
  168. DCHECK_NE(status_recorder.status(), ExecutionStatus::kSuccess);
  169. return;
  170. }
  171. if (last_execution_time_) {
  172. // The max of this histogram is 3m since only the distribution and count
  173. // of smaller values is important.
  174. base::UmaHistogramMediumTimes(
  175. "OptimizationGuide.ModelExecutor.TimeSincePreviousRun." +
  176. GetStringNameForOptimizationTarget(optimization_target_),
  177. base::TimeTicks::Now() - *last_execution_time_);
  178. }
  179. last_execution_time_ = base::TimeTicks::Now();
  180. DCHECK(loaded_model_);
  181. absl::optional<OutputType> output;
  182. // IMPORTANT: Once the arm method is called, disarm must be called when the
  183. // model execution finishes. Do NOT early-return in this next block.
  184. if (watchdog_) {
  185. watchdog_->ArmWithTask(loaded_model_.get());
  186. }
  187. {
  188. TRACE_EVENT1("browser", "OptGuideModelExecutor::Execute",
  189. "OptimizationTarget",
  190. optimization_guide::GetStringNameForOptimizationTarget(
  191. optimization_target_));
  192. base::ElapsedThreadTimer execution_timer;
  193. base::TimeTicks execute_start_time = base::TimeTicks::Now();
  194. output = Execute(loaded_model_.get(), status_recorder.mutable_status(),
  195. args...);
  196. DCHECK_NE(status_recorder.status(), ExecutionStatus::kUnknown);
  197. // The max of this histogram is 1 hour because we want to understand
  198. // tail behavior and catch long running model executions.
  199. base::UmaHistogramLongTimes(
  200. "OptimizationGuide.ModelExecutor.ExecutionLatency." +
  201. GetStringNameForOptimizationTarget(optimization_target_),
  202. base::TimeTicks::Now() - execute_start_time);
  203. base::UmaHistogramLongTimes(
  204. "OptimizationGuide.ModelExecutor.ExecutionThreadTime." +
  205. GetStringNameForOptimizationTarget(optimization_target_),
  206. execution_timer.Elapsed());
  207. }
  208. if (watchdog_) {
  209. watchdog_->DisarmOnExecutionComplete();
  210. }
  211. DCHECK(callback_on_complete);
  212. reply_task_runner_->PostTask(
  213. FROM_HERE, base::BindOnce(std::move(callback_on_complete), output));
  214. OnExecutionComplete();
  215. }
  216. TFLiteModelExecutor(const TFLiteModelExecutor&) = delete;
  217. TFLiteModelExecutor& operator=(const TFLiteModelExecutor&) = delete;
  218. protected:
  219. using ModelExecutionTask =
  220. tflite::task::core::BaseTaskApi<OutputType, InputTypes...>;
  221. // Executes the model using |execution_task| on |args|, returning the model
  222. // output and setting |out_status| with the status of the execution attempt.
  223. virtual absl::optional<OutputType> Execute(ModelExecutionTask* execution_task,
  224. ExecutionStatus* out_status,
  225. InputTypes... args) = 0;
  226. // Builds a model execution task using |model_file|.
  227. virtual std::unique_ptr<ModelExecutionTask> BuildModelExecutionTask(
  228. base::MemoryMappedFile* model_file,
  229. ExecutionStatus* out_status) = 0;
  230. private:
  231. // A true return value indicates the model was loaded successfully, false
  232. // otherwise.
  233. bool LoadModelFile(ExecutionStatus* out_status) {
  234. TRACE_EVENT1("browser", "OptGuideModelExecutor::LoadModelFile",
  235. "OptimizationTarget",
  236. optimization_guide::GetStringNameForOptimizationTarget(
  237. optimization_target_));
  238. DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
  239. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  240. UnloadModel();
  241. base::UmaHistogramBoolean(
  242. "OptimizationGuide.ModelExecutor.ModelAvailableToLoad." +
  243. GetStringNameForOptimizationTarget(optimization_target_),
  244. !!model_file_path_);
  245. if (!model_file_path_) {
  246. *out_status = ExecutionStatus::kErrorModelFileNotAvailable;
  247. return false;
  248. }
  249. base::TimeTicks loading_start_time = base::TimeTicks::Now();
  250. std::unique_ptr<base::MemoryMappedFile> model_fb =
  251. std::make_unique<base::MemoryMappedFile>();
  252. if (!model_fb->Initialize(*model_file_path_)) {
  253. *out_status = ExecutionStatus::kErrorModelFileNotValid;
  254. return false;
  255. }
  256. model_fb_ = std::move(model_fb);
  257. loaded_model_ = BuildModelExecutionTask(model_fb_.get(), out_status);
  258. if (!!loaded_model_) {
  259. // We only want to record successful loading times.
  260. base::UmaHistogramTimes(
  261. "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." +
  262. optimization_guide::GetStringNameForOptimizationTarget(
  263. optimization_target_),
  264. base::TimeTicks::Now() - loading_start_time);
  265. }
  266. // Local histogram used in integration testing.
  267. base::BooleanHistogram::FactoryGet(
  268. "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
  269. optimization_guide::GetStringNameForOptimizationTarget(
  270. optimization_target_),
  271. base::Histogram::kNoFlags)
  272. ->Add(!!loaded_model_);
  273. return !!loaded_model_;
  274. }
  275. void OnExecutionComplete() {
  276. DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
  277. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  278. if (should_unload_model_on_complete_) {
  279. UnloadModel();
  280. }
  281. }
  282. proto::OptimizationTarget optimization_target_ =
  283. proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
  284. bool should_unload_model_on_complete_ = true;
  285. std::unique_ptr<ModelExecutionTimeoutWatchdog<OutputType, InputTypes...>,
  286. base::OnTaskRunnerDeleter>
  287. watchdog_;
  288. scoped_refptr<base::SequencedTaskRunner> execution_task_runner_;
  289. scoped_refptr<base::SequencedTaskRunner> reply_task_runner_;
  290. // The time that the model was last executed. Logged in metrics for the second
  291. // and following runs.
  292. absl::optional<base::TimeTicks> last_execution_time_
  293. GUARDED_BY_CONTEXT(sequence_checker_);
  294. // The model file path to be loaded. May be nullopt if no model has been
  295. // downloaded yet.
  296. absl::optional<base::FilePath> model_file_path_
  297. GUARDED_BY_CONTEXT(sequence_checker_);
  298. // Note on lifetimes: |loaded_model_| and |model_fb_| both share the same
  299. // lifetime, being set in |LoadModelFile()| and being destroyed in
  300. // |ResetModelFile()|.
  301. std::unique_ptr<ModelExecutionTask> loaded_model_
  302. GUARDED_BY_CONTEXT(sequence_checker_);
  303. // This will only be non-null when |model_file_path_| is set, and while the
  304. // model is loaded which is managed by a feature flag.
  305. std::unique_ptr<base::MemoryMappedFile> model_fb_
  306. GUARDED_BY_CONTEXT(sequence_checker_);
  307. SEQUENCE_CHECKER(sequence_checker_);
  308. };
  309. } // namespace optimization_guide
  310. #endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_TFLITE_MODEL_EXECUTOR_H_