prediction_model_executor.cc 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/permissions/prediction_service/prediction_model_executor.h"
  5. #include "base/notreached.h"
  6. #include "components/permissions/prediction_service/prediction_common.h"
  7. #include "components/permissions/prediction_service/prediction_request_features.h"
  8. #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"
  9. namespace permissions {
  10. PredictionModelExecutor::PredictionModelExecutor() = default;
  11. PredictionModelExecutor::~PredictionModelExecutor() = default;
  12. bool PredictionModelExecutor::Preprocess(
  13. const std::vector<TfLiteTensor*>& input_tensors,
  14. const GeneratePredictionsRequest& input) {
  15. switch (input.permission_features()[0].permission_type_case()) {
  16. case PermissionFeatures::kNotificationPermission:
  17. request_type_ = RequestType::kNotifications;
  18. break;
  19. case PermissionFeatures::kGeolocationPermission:
  20. request_type_ = RequestType::kGeolocation;
  21. break;
  22. default:
  23. NOTREACHED();
  24. }
  25. if (!tflite::task::core::PopulateTensor<float>(
  26. input.client_features().client_stats().avg_deny_rate(),
  27. input_tensors[0])
  28. .ok()) {
  29. return false;
  30. }
  31. if (!tflite::task::core::PopulateTensor<float>(
  32. input.client_features().client_stats().avg_dismiss_rate(),
  33. input_tensors[1])
  34. .ok()) {
  35. return false;
  36. }
  37. if (!tflite::task::core::PopulateTensor<float>(
  38. input.client_features().client_stats().avg_grant_rate(),
  39. input_tensors[2])
  40. .ok()) {
  41. return false;
  42. }
  43. if (!tflite::task::core::PopulateTensor<float>(
  44. input.client_features().client_stats().avg_ignore_rate(),
  45. input_tensors[3])
  46. .ok()) {
  47. return false;
  48. }
  49. if (!tflite::task::core::PopulateTensor<float>(
  50. input.permission_features()[0].permission_stats().avg_deny_rate(),
  51. input_tensors[4])
  52. .ok()) {
  53. return false;
  54. }
  55. if (!tflite::task::core::PopulateTensor<float>(
  56. input.permission_features()[0].permission_stats().avg_dismiss_rate(),
  57. input_tensors[5])
  58. .ok()) {
  59. return false;
  60. }
  61. if (!tflite::task::core::PopulateTensor<float>(
  62. input.permission_features()[0].permission_stats().avg_grant_rate(),
  63. input_tensors[6])
  64. .ok()) {
  65. return false;
  66. }
  67. if (!tflite::task::core::PopulateTensor<float>(
  68. input.permission_features()[0].permission_stats().avg_ignore_rate(),
  69. input_tensors[7])
  70. .ok()) {
  71. return false;
  72. }
  73. if (!tflite::task::core::PopulateTensor<int64_t>(
  74. static_cast<int64_t>(input.permission_features()[0]
  75. .permission_stats()
  76. .prompts_count()),
  77. input_tensors[8])
  78. .ok()) {
  79. return false;
  80. }
  81. if (!tflite::task::core::PopulateTensor<int64_t>(
  82. static_cast<int64_t>(
  83. input.client_features().client_stats().prompts_count()),
  84. input_tensors[9])
  85. .ok()) {
  86. return false;
  87. }
  88. if (!tflite::task::core::PopulateTensor<int64_t>(
  89. static_cast<int64_t>(input.client_features().gesture_enum()),
  90. input_tensors[10])
  91. .ok()) {
  92. return false;
  93. }
  94. if (!tflite::task::core::PopulateTensor<int64_t>(
  95. static_cast<int64_t>(input.client_features().platform_enum()),
  96. input_tensors[11])
  97. .ok()) {
  98. return false;
  99. }
  100. return true;
  101. }
  102. absl::optional<GeneratePredictionsResponse>
  103. PredictionModelExecutor::Postprocess(
  104. const std::vector<const TfLiteTensor*>& output_tensors) {
  105. DCHECK(request_type_ == RequestType::kNotifications ||
  106. request_type_ == RequestType::kGeolocation);
  107. std::vector<float> data;
  108. if (!tflite::task::core::PopulateVector<float>(output_tensors[0], &data)
  109. .ok()) {
  110. return absl::nullopt;
  111. }
  112. GeneratePredictionsResponse response;
  113. float threshold = request_type_ == RequestType::kNotifications
  114. ? kNotificationPredictionsThreshold
  115. : kGeolocationPredictionsThreshold;
  116. response.mutable_prediction()
  117. ->Add()
  118. ->mutable_grant_likelihood()
  119. ->set_discretized_likelihood(
  120. data[1] >= threshold
  121. ? PermissionPrediction_Likelihood_DiscretizedLikelihood_VERY_UNLIKELY
  122. : PermissionPrediction_Likelihood_DiscretizedLikelihood_DISCRETIZED_LIKELIHOOD_UNSPECIFIED);
  123. return response;
  124. }
  125. } // namespace permissions