0006-check-cancel-flag-before-calling-invoke.patch 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. From 98f819d7d88b6f03b3bbab2d116d2fa31674a154 Mon Sep 17 00:00:00 2001
  2. From: Robert Ogden <robertogden@chromium.org>
  3. Date: Wed, 25 May 2022 10:54:30 -0700
  4. Subject: [PATCH 6/9] check cancel flag before calling invoke
  5. ---
  6. .../cc/port/default/tflite_wrapper.cc | 14 ++++++++++----
  7. 1 file changed, 10 insertions(+), 4 deletions(-)
  8. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
  9. index d47c1ce7e5179..11f9d584cfdd0 100644
  10. --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
  11. +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
  12. @@ -258,8 +258,10 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
  13. const std::function<absl::Status(tflite::Interpreter* interpreter)>&
  14. set_inputs) {
  15. RETURN_IF_ERROR(set_inputs(interpreter_.get()));
  16. - // Reset cancel flag before calling `Invoke()`.
  17. - cancel_flag_.Set(false);
  18. + if (cancel_flag_.Get()) {
  19. + cancel_flag_.Set(false);
  20. + return absl::CancelledError("cancelled before Invoke() was called");
  21. + }
  22. TfLiteStatus status = kTfLiteError;
  23. if (fallback_on_execution_error_) {
  24. status = InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
  25. @@ -273,6 +275,7 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
  26. // Assume the inference is cancelled successfully if Invoke() returns
  27. // kTfLiteError and the cancel flag is `true`.
  28. if (status == kTfLiteError && cancel_flag_.Get()) {
  29. + cancel_flag_.Set(false);
  30. return absl::CancelledError("Invoke() cancelled.");
  31. }
  32. if (delegate_) {
  33. @@ -289,14 +292,17 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
  34. }
  35. absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
  36. - // Reset cancel flag before calling `Invoke()`.
  37. - cancel_flag_.Set(false);
  38. + if (cancel_flag_.Get()) {
  39. + cancel_flag_.Set(false);
  40. + return absl::CancelledError("cancelled before Invoke() was called");
  41. + }
  42. TfLiteStatus status = interpreter_->Invoke();
  43. if (status != kTfLiteOk) {
  44. // Assume InvokeWithoutFallback() is guarded under caller's synchronization.
  45. // Assume the inference is cancelled successfully if Invoke() returns
  46. // kTfLiteError and the cancel flag is `true`.
  47. if (status == kTfLiteError && cancel_flag_.Get()) {
  48. + cancel_flag_.Set(false);
  49. return absl::CancelledError("Invoke() cancelled.");
  50. }
  51. return absl::InternalError("Invoke() failed.");
  52. --
  53. 2.36.1.124.g0e6072fb45-goog