base_model_executor_helpers.h

// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_BASE_MODEL_EXECUTOR_HELPERS_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_BASE_MODEL_EXECUTOR_HELPERS_H_

#include <memory>
#include <vector>

#include "base/check.h"
#include "base/memory/raw_ptr.h"
#include "components/optimization_guide/core/execution_status.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"

namespace optimization_guide {
// Delegate that handles the pre-processing and post-processing steps of model
// execution, converting between caller-provided inputs/outputs and TFLite
// tensors.
template <class OutputType, class... InputTypes>
class InferenceDelegate {
 public:
  // Preprocesses |args| into |input_tensors|. Returns true on success.
  virtual bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          InputTypes... args) = 0;

  // Postprocesses |output_tensors| into the desired |OutputType|, returning
  // absl::nullopt on error.
  virtual absl::optional<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors) = 0;
};
  25. template <class OutputType, class... InputTypes>
  26. class GenericModelExecutionTask
  27. : public tflite::task::core::BaseTaskApi<OutputType, InputTypes...> {
  28. public:
  29. GenericModelExecutionTask(
  30. std::unique_ptr<tflite::task::core::TfLiteEngine> tflite_engine,
  31. InferenceDelegate<OutputType, InputTypes...>* delegate)
  32. : tflite::task::core::BaseTaskApi<OutputType, InputTypes...>(
  33. std::move(tflite_engine)),
  34. delegate_(delegate) {
  35. DCHECK(delegate_);
  36. }
  37. ~GenericModelExecutionTask() override = default;
  38. // Executes the model using |args| and returns the output if the model was
  39. // executed successfully.
  40. absl::optional<OutputType> Execute(ExecutionStatus* out_status,
  41. InputTypes... args) {
  42. tflite::support::StatusOr<OutputType> maybe_output = this->Infer(args...);
  43. if (absl::IsCancelled(maybe_output.status())) {
  44. *out_status = ExecutionStatus::kErrorCancelled;
  45. return absl::nullopt;
  46. }
  47. if (!maybe_output.ok()) {
  48. *out_status = ExecutionStatus::kErrorUnknown;
  49. return absl::nullopt;
  50. }
  51. *out_status = ExecutionStatus::kSuccess;
  52. return maybe_output.value();
  53. }
 protected:
  // BaseTaskApi:
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          InputTypes... args) override {
    bool success = delegate_->Preprocess(input_tensors, args...);
    if (success) {
      return absl::OkStatus();
    }
    return absl::InternalError(
        "error during preprocessing. See stderr for more information if "
        "available");
  }

  tflite::support::StatusOr<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      InputTypes... api_inputs) override {
    absl::optional<OutputType> output = delegate_->Postprocess(output_tensors);
    if (!output) {
      return absl::InternalError(
          "error during postprocessing. See stderr for more information if "
          "available");
    }
    return *output;
  }
 private:
  // Guaranteed to outlive this.
  raw_ptr<InferenceDelegate<OutputType, InputTypes...>> delegate_;
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_BASE_MODEL_EXECUTOR_HELPERS_H_
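
The sketch below illustrates how the two templates above are meant to fit together; it is not part of the header. The delegate class, the input/output types, and the engine parameter are hypothetical stand-ins, and it assumes an already-initialized tflite::task::core::TfLiteEngine plus the usual includes (<string>, the header above).

// Hypothetical delegate that converts a string input into input tensors and
// reads a float vector back out of the output tensors.
class ExampleDelegate
    : public optimization_guide::InferenceDelegate<std::vector<float>,
                                                   const std::string&> {
 public:
  bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                  const std::string& input) override {
    // Tokenize |input| and write it into |input_tensors| here.
    return true;
  }
  absl::optional<std::vector<float>> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors) override {
    // Read model scores out of |output_tensors| here.
    return std::vector<float>();
  }
};

// Assumes |engine| has already loaded a model. The delegate is declared
// before the task so it outlives it, matching the raw_ptr contract.
absl::optional<std::vector<float>> RunExample(
    std::unique_ptr<tflite::task::core::TfLiteEngine> engine,
    const std::string& input) {
  ExampleDelegate delegate;
  optimization_guide::GenericModelExecutionTask<std::vector<float>,
                                                const std::string&>
      task(std::move(engine), &delegate);
  optimization_guide::ExecutionStatus status;
  return task.Execute(&status, input);
}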