bert_model_executor.cc

// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/optimization_guide/core/bert_model_executor.h"

#include "base/trace_event/trace_event.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/tflite_op_resolver.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"

namespace optimization_guide {
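
// |num_threads_| falls back to -1 when features::OverrideNumThreadsForOptTarget
// provides no override; per TFLite convention, a value of -1 lets the runtime
// choose its own thread count.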
BertModelExecutor::BertModelExecutor(
    proto::OptimizationTarget optimization_target)
    : optimization_target_(optimization_target),
      num_threads_(features::OverrideNumThreadsForOptTarget(optimization_target)
                       .value_or(-1)) {}

BertModelExecutor::~BertModelExecutor() = default;
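
// Runs the underlying BertNLClassifier on |input| and translates the
// resulting absl::Status into an ExecutionStatus for the caller.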
absl::optional<std::vector<tflite::task::core::Category>>
BertModelExecutor::Execute(ModelExecutionTask* execution_task,
                           ExecutionStatus* out_status,
                           const std::string& input) {
  if (input.empty()) {
    *out_status = ExecutionStatus::kErrorEmptyOrInvalidInput;
    return absl::nullopt;
  }
  TRACE_EVENT2("browser", "BertModelExecutor::Execute", "optimization_target",
               GetStringNameForOptimizationTarget(optimization_target_),
               "input_length", input.size());
  auto status_or_result =
      static_cast<tflite::task::text::BertNLClassifier*>(execution_task)
          ->ClassifyText(input);
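
  // Report cancellation separately so callers can distinguish it from other
  // classification failures.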
  if (absl::IsCancelled(status_or_result.status())) {
    *out_status = ExecutionStatus::kErrorCancelled;
    return absl::nullopt;
  }
  if (!status_or_result.ok()) {
    *out_status = ExecutionStatus::kErrorUnknown;
    return absl::nullopt;
  }
  *out_status = ExecutionStatus::kSuccess;
  return *status_or_result;
}
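
// Builds the BertNLClassifier task from the memory-mapped model file. Returns
// nullptr and sets |out_status| to kErrorModelFileNotValid if the model cannot
// be loaded.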
std::unique_ptr<BertModelExecutor::ModelExecutionTask>
BertModelExecutor::BuildModelExecutionTask(base::MemoryMappedFile* model_file,
                                           ExecutionStatus* out_status) {
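  // Copy the memory-mapped model contents into the classifier options and
  // configure the CPU thread count from |num_threads_|.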
  tflite::task::text::BertNLClassifierOptions options;
  *options.mutable_base_options()
       ->mutable_model_file()
       ->mutable_file_content() = std::string(
      reinterpret_cast<const char*>(model_file->data()), model_file->length());
  options.mutable_base_options()
      ->mutable_compute_settings()
      ->mutable_tflite_settings()
      ->mutable_cpu_settings()
      ->set_num_threads(num_threads_);
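
  // Construct the classifier with Chromium's TFLiteOpResolver as the op
  // resolver.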
  auto maybe_nl_classifier =
      tflite::task::text::BertNLClassifier::CreateFromOptions(
          std::move(options), std::make_unique<TFLiteOpResolver>());
  if (maybe_nl_classifier.ok())
    return std::move(maybe_nl_classifier.value());
  *out_status = ExecutionStatus::kErrorModelFileNotValid;
  DLOG(ERROR) << "Unable to load BERT model: "
              << maybe_nl_classifier.status().ToString();
  return nullptr;
}

}  // namespace optimization_guide