123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542 |
- // Copyright 2022 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/optimization_guide/core/entity_annotator_native_library.h"
- #include "base/base_paths.h"
- #include "base/compiler_specific.h"
- #include "base/logging.h"
- #include "base/memory/ptr_util.h"
- #include "base/metrics/histogram_functions.h"
- #include "base/path_service.h"
- #include "build/build_config.h"
- #include "components/optimization_guide/core/model_util.h"
- #include "components/optimization_guide/core/optimization_guide_util.h"
- #include "components/optimization_guide/proto/page_entities_model_metadata.pb.h"
- #if BUILDFLAG(IS_MAC)
- #include "base/mac/bundle_locations.h"
- #include "base/mac/foundation_util.h"
- #endif
- // IMPORTANT: All functions in this file that call dlsym()'ed
- // functions should be annotated with DISABLE_CFI_DLSYM.
- namespace optimization_guide {
- namespace {
// Base names of the files looked up directly among the model's additional
// files when configuring the entity annotator.
constexpr char kModelMetadataBaseName[] = "model_metadata.pb";
constexpr char kWordEmbeddingsBaseName[] = "word_embeddings";

// Base names of the per-slice files. The expected on-disk base name is the
// slice id joined to one of these (see GetSliceBaseName()).
constexpr char kNameTableBaseName[] = "entities_names";
constexpr char kMetadataTableBaseName[] = "entities_metadata";
constexpr char kNameFilterBaseName[] = "entities_names_filter";
constexpr char kPrefixFilterBaseName[] = "entities_prefixes_filter";
- // Sets |field_to_set| with the full file path of |base_name|'s entry in
- // |base_to_full_file_path|. Returns whether |base_name| is in
- // |base_to_full_file_path|.
- absl::optional<std::string> GetFilePathFromMap(
- const std::string& base_name,
- const base::flat_map<std::string, base::FilePath>& base_to_full_file_path) {
- auto it = base_to_full_file_path.find(base_name);
- return it == base_to_full_file_path.end()
- ? absl::nullopt
- : absl::make_optional(FilePathToString(it->second));
- }
// Builds the base name expected for a file belonging to |slice|: the slice id
// and |base_name| joined by a single '-'.
std::string GetSliceBaseName(const std::string& slice,
                             const std::string& base_name) {
  std::string slice_base_name;
  slice_base_name.reserve(slice.size() + 1 + base_name.size());
  slice_base_name.append(slice);
  slice_base_name.push_back('-');
  slice_base_name.append(base_name);
  return slice_base_name;
}
- class ScopedEntityAnnotatorCreationStatusRecorder {
- public:
- ScopedEntityAnnotatorCreationStatusRecorder() = default;
- ~ScopedEntityAnnotatorCreationStatusRecorder() {
- DCHECK_NE(status_, EntityAnnotatorCreationStatus::kUnknown);
- base::UmaHistogramEnumeration(
- "OptimizationGuide.PageEntitiesModelExecutor.CreationStatus", status_);
- }
- void set_status(EntityAnnotatorCreationStatus status) { status_ = status; }
- private:
- EntityAnnotatorCreationStatus status_ =
- EntityAnnotatorCreationStatus::kUnknown;
- };
- } // namespace
- EntityAnnotatorNativeLibrary::EntityAnnotatorNativeLibrary(
- base::NativeLibrary native_library,
- bool should_provide_filter_path)
- : native_library_(std::move(native_library)),
- should_provide_filter_path_(should_provide_filter_path) {
- LoadFunctions();
- }
- EntityAnnotatorNativeLibrary::~EntityAnnotatorNativeLibrary() = default;
- // static
- std::unique_ptr<EntityAnnotatorNativeLibrary>
- EntityAnnotatorNativeLibrary::Create(bool should_provide_filter_path) {
- base::FilePath base_dir;
- #if !BUILDFLAG(IS_ANDROID)
- #if BUILDFLAG(IS_MAC)
- if (base::mac::AmIBundled()) {
- base_dir = base::mac::FrameworkBundlePath().Append("Libraries");
- } else {
- #endif // BUILDFLAG(IS_MAC)
- if (!base::PathService::Get(base::DIR_MODULE, &base_dir)) {
- LOG(ERROR) << "Error getting app dir";
- return nullptr;
- }
- #if BUILDFLAG(IS_MAC)
- }
- #endif // BUILDFLAG(IS_MAC)
- #endif // !BUILDFLAG(IS_ANDROID)
- base::NativeLibraryLoadError error;
- base::NativeLibrary native_library = base::LoadNativeLibrary(
- base_dir.AppendASCII(
- base::GetNativeLibraryName("optimization_guide_internal")),
- &error);
- base::UmaHistogramBoolean(
- "OptimizationGuide.EntityAnnotatorNativeLibrary.InitiatedSuccessfully",
- (native_library != nullptr));
- if (!native_library) {
- LOG(ERROR) << "Failed to initialize optimization guide internal: "
- << error.ToString();
- return nullptr;
- }
- std::unique_ptr<EntityAnnotatorNativeLibrary>
- entity_annotator_native_library =
- base::WrapUnique<EntityAnnotatorNativeLibrary>(
- new EntityAnnotatorNativeLibrary(std::move(native_library),
- should_provide_filter_path));
- if (entity_annotator_native_library->IsValid()) {
- return entity_annotator_native_library;
- }
- LOG(ERROR) << "Could not find all required functions for optimization guide "
- "internal library";
- return nullptr;
- }
- DISABLE_CFI_DLSYM
- void EntityAnnotatorNativeLibrary::LoadFunctions() {
- get_max_supported_feature_flag_func_ =
- reinterpret_cast<GetMaxSupportedFeatureFlagFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorGetMaxSupportedFeatureFlag"));
- create_from_options_func_ = reinterpret_cast<CreateFromOptionsFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorCreateFromOptions"));
- get_creation_error_func_ = reinterpret_cast<GetCreationErrorFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_, "OptimizationGuideEntityAnnotatorGetCreationError"));
- delete_func_ =
- reinterpret_cast<DeleteFunc>(base::GetFunctionPointerFromNativeLibrary(
- native_library_, "OptimizationGuideEntityAnnotatorDelete"));
- annotate_job_create_func_ = reinterpret_cast<AnnotateJobCreateFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorAnnotateJobCreate"));
- annotate_job_delete_func_ = reinterpret_cast<AnnotateJobDeleteFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorAnnotateJobDelete"));
- run_annotate_job_func_ = reinterpret_cast<RunAnnotateJobFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_, "OptimizationGuideEntityAnnotatorRunAnnotateJob"));
- annotate_get_output_metadata_at_index_func_ = reinterpret_cast<
- AnnotateGetOutputMetadataAtIndexFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorAnnotateGetOutputMetadataAtIndex"));
- annotate_get_output_metadata_score_at_index_func_ =
- reinterpret_cast<AnnotateGetOutputMetadataScoreAtIndexFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorAnnotateGetOutputMetadataScoreAt"
- "Index"));
- entity_metadata_job_create_func_ =
- reinterpret_cast<EntityMetadataJobCreateFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorEntityMetadataJobCreate"));
- entity_metadata_job_delete_func_ =
- reinterpret_cast<EntityMetadataJobDeleteFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorEntityMetadataJobDelete"));
- run_entity_metadata_job_func_ = reinterpret_cast<RunEntityMetadataJobFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorRunEntityMetadataJob"));
- options_create_func_ = reinterpret_cast<OptionsCreateFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_, "OptimizationGuideEntityAnnotatorOptionsCreate"));
- options_set_model_file_path_func_ =
- reinterpret_cast<OptionsSetModelFilePathFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorOptionsSetModelFilePath"));
- options_set_model_metadata_file_path_func_ = reinterpret_cast<
- OptionsSetModelMetadataFilePathFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorOptionsSetModelMetadataFilePath"));
- options_set_word_embeddings_file_path_func_ = reinterpret_cast<
- OptionsSetWordEmbeddingsFilePathFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorOptionsSetWordEmbeddingsFilePath"));
- options_add_model_slice_func_ = reinterpret_cast<OptionsAddModelSliceFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityAnnotatorOptionsAddModelSlice"));
- options_delete_func_ = reinterpret_cast<OptionsDeleteFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_, "OptimizationGuideEntityAnnotatorOptionsDelete"));
- entity_metadata_get_entity_id_func_ =
- reinterpret_cast<EntityMetadataGetEntityIdFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_, "OptimizationGuideEntityMetadataGetEntityID"));
- entity_metadata_get_human_readable_name_func_ =
- reinterpret_cast<EntityMetadataGetHumanReadableNameFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetHumanReadableName"));
- entity_metadata_get_human_readable_categories_count_func_ = reinterpret_cast<
- EntityMetadataGetHumanReadableCategoriesCountFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetHumanReadableCategoriesCount"));
- entity_metadata_get_human_readable_category_name_at_index_func_ =
- reinterpret_cast<EntityMetadataGetHumanReadableCategoryNameAtIndexFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetHumanReadableCategoryNameAtInd"
- "ex"));
- entity_metadata_get_human_readable_category_score_at_index_func_ =
- reinterpret_cast<EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetHumanReadableCategoryScoreAtIn"
- "dex"));
- entity_metadata_get_human_readable_aliases_count_func_ =
- reinterpret_cast<EntityMetadataGetHumanReadableAliasesCountFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetHumanReadableAliasesCount"));
- entity_metadata_get_human_readable_alias_at_index_func_ =
- reinterpret_cast<EntityMetadataGetHumanReadableAliasAtIndexFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetHumanReadableAliasAtIndex"));
- entity_metadata_get_collections_count_func_ =
- reinterpret_cast<EntityMetadataGetCollectionsCountFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetCollectionsCount"));
- entity_metadata_get_collection_at_index_func_ =
- reinterpret_cast<EntityMetadataGetCollectionAtIndexFunc>(
- base::GetFunctionPointerFromNativeLibrary(
- native_library_,
- "OptimizationGuideEntityMetadataGetCollectionAtIndex"));
- }
- DISABLE_CFI_DLSYM
- bool EntityAnnotatorNativeLibrary::IsValid() const {
- return get_max_supported_feature_flag_func_ && create_from_options_func_ &&
- get_creation_error_func_ && delete_func_ &&
- annotate_job_create_func_ && annotate_job_delete_func_ &&
- run_annotate_job_func_ &&
- annotate_get_output_metadata_at_index_func_ &&
- annotate_get_output_metadata_score_at_index_func_ &&
- entity_metadata_job_create_func_ && entity_metadata_job_delete_func_ &&
- run_entity_metadata_job_func_ && options_create_func_ &&
- options_set_model_file_path_func_ &&
- options_set_model_metadata_file_path_func_ &&
- options_set_word_embeddings_file_path_func_ &&
- options_add_model_slice_func_ && options_delete_func_ &&
- entity_metadata_get_entity_id_func_ &&
- entity_metadata_get_human_readable_name_func_ &&
- entity_metadata_get_human_readable_categories_count_func_ &&
- entity_metadata_get_human_readable_category_name_at_index_func_ &&
- entity_metadata_get_human_readable_category_score_at_index_func_ &&
- entity_metadata_get_human_readable_aliases_count_func_ &&
- entity_metadata_get_human_readable_alias_at_index_func_ &&
- entity_metadata_get_collections_count_func_ &&
- entity_metadata_get_collection_at_index_func_;
- }
- DISABLE_CFI_DLSYM
- int32_t EntityAnnotatorNativeLibrary::GetMaxSupportedFeatureFlag() {
- DCHECK(IsValid());
- if (!IsValid()) {
- return -1;
- }
- return get_max_supported_feature_flag_func_();
- }
- DISABLE_CFI_DLSYM
- void* EntityAnnotatorNativeLibrary::CreateEntityAnnotator(
- const ModelInfo& model_info) {
- ScopedEntityAnnotatorCreationStatusRecorder recorder;
- DCHECK(IsValid());
- if (!IsValid()) {
- recorder.set_status(EntityAnnotatorCreationStatus::kLibraryInvalid);
- return nullptr;
- }
- void* options = options_create_func_();
- EntityAnnotatorCreationStatus status;
- if (!PopulateEntityAnnotatorOptionsFromModelInfo(options, model_info,
- &status)) {
- options_delete_func_(options);
- recorder.set_status(status);
- return nullptr;
- }
- void* entity_annotator = create_from_options_func_(options);
- const char* creation_error = get_creation_error_func_(entity_annotator);
- if (creation_error) {
- LOG(ERROR) << "Failed to create entity annotator: " << creation_error;
- recorder.set_status(EntityAnnotatorCreationStatus::kInitializationFailure);
- DeleteEntityAnnotator(entity_annotator);
- entity_annotator = nullptr;
- } else {
- recorder.set_status(EntityAnnotatorCreationStatus::kSuccess);
- }
- options_delete_func_(options);
- return entity_annotator;
- }
- DISABLE_CFI_DLSYM
- bool EntityAnnotatorNativeLibrary::PopulateEntityAnnotatorOptionsFromModelInfo(
- void* options,
- const ModelInfo& model_info,
- EntityAnnotatorCreationStatus* status) {
- // We don't know which files are intended for use if we don't have model
- // metadata, so return early.
- if (!model_info.GetModelMetadata()) {
- *status = EntityAnnotatorCreationStatus::kMissingModelMetadata;
- return false;
- }
- // // Validate the model metadata.
- absl::optional<proto::PageEntitiesModelMetadata> entities_model_metadata =
- ParsedAnyMetadata<proto::PageEntitiesModelMetadata>(
- model_info.GetModelMetadata().value());
- if (!entities_model_metadata) {
- *status = EntityAnnotatorCreationStatus::kMissingEntitiesModelMetadata;
- return false;
- }
- if (entities_model_metadata->slice_size() == 0) {
- *status = EntityAnnotatorCreationStatus::
- kMissingEntitiesModelMetadataSliceSpecification;
- return false;
- }
- // Build the entity annotator options.
- options_set_model_file_path_func_(
- options, FilePathToString(model_info.GetModelFilePath()).c_str());
- // Attach the additional files required by the model.
- base::flat_map<std::string, base::FilePath> base_to_full_file_path;
- for (const auto& model_file : model_info.GetAdditionalFiles()) {
- base_to_full_file_path.insert(
- {FilePathToString(model_file.BaseName()), model_file});
- }
- absl::optional<std::string> model_metadata_file_path =
- GetFilePathFromMap(kModelMetadataBaseName, base_to_full_file_path);
- if (!model_metadata_file_path) {
- *status = EntityAnnotatorCreationStatus::
- kMissingAdditionalEntitiesModelMetadataPath;
- return false;
- }
- options_set_model_metadata_file_path_func_(options,
- model_metadata_file_path->c_str());
- absl::optional<std::string> word_embeddings_file_path =
- GetFilePathFromMap(kWordEmbeddingsBaseName, base_to_full_file_path);
- if (!word_embeddings_file_path) {
- *status =
- EntityAnnotatorCreationStatus::kMissingAdditionalWordEmbeddingsPath;
- return false;
- }
- options_set_word_embeddings_file_path_func_(
- options, word_embeddings_file_path->c_str());
- base::flat_set<std::string> slices(entities_model_metadata->slice().begin(),
- entities_model_metadata->slice().end());
- for (const auto& slice_id : slices) {
- absl::optional<std::string> name_filter_path;
- if (should_provide_filter_path_) {
- name_filter_path =
- GetFilePathFromMap(GetSliceBaseName(slice_id, kNameFilterBaseName),
- base_to_full_file_path);
- if (!name_filter_path) {
- *status =
- EntityAnnotatorCreationStatus::kMissingAdditionalNameFilterPath;
- return false;
- }
- }
- absl::optional<std::string> name_table_path = GetFilePathFromMap(
- GetSliceBaseName(slice_id, kNameTableBaseName), base_to_full_file_path);
- if (!name_table_path) {
- *status = EntityAnnotatorCreationStatus::kMissingAdditionalNameTablePath;
- return false;
- }
- absl::optional<std::string> prefix_filter_path;
- if (should_provide_filter_path_) {
- prefix_filter_path =
- GetFilePathFromMap(GetSliceBaseName(slice_id, kPrefixFilterBaseName),
- base_to_full_file_path);
- if (!prefix_filter_path) {
- *status =
- EntityAnnotatorCreationStatus::kMissingAdditionalPrefixFilterPath;
- return false;
- }
- }
- absl::optional<std::string> metadata_table_path =
- GetFilePathFromMap(GetSliceBaseName(slice_id, kMetadataTableBaseName),
- base_to_full_file_path);
- if (!metadata_table_path) {
- *status =
- EntityAnnotatorCreationStatus::kMissingAdditionalMetadataTablePath;
- return false;
- }
- options_add_model_slice_func_(
- options, slice_id.c_str(), name_filter_path.value_or("").c_str(),
- name_table_path->c_str(), prefix_filter_path.value_or("").c_str(),
- metadata_table_path->c_str());
- }
- return true;
- }
- DISABLE_CFI_DLSYM
- void EntityAnnotatorNativeLibrary::DeleteEntityAnnotator(
- void* entity_annotator) {
- DCHECK(IsValid());
- if (!IsValid()) {
- return;
- }
- delete_func_(reinterpret_cast<void*>(entity_annotator));
- }
- DISABLE_CFI_DLSYM
- absl::optional<std::vector<ScoredEntityMetadata>>
- EntityAnnotatorNativeLibrary::AnnotateText(void* annotator,
- const std::string& text) {
- DCHECK(IsValid());
- if (!IsValid()) {
- return absl::nullopt;
- }
- if (!annotator) {
- return absl::nullopt;
- }
- void* job = annotate_job_create_func_(reinterpret_cast<void*>(annotator));
- int32_t output_metadata_count = run_annotate_job_func_(job, text.c_str());
- if (output_metadata_count <= 0) {
- return absl::nullopt;
- }
- std::vector<ScoredEntityMetadata> scored_md;
- scored_md.reserve(output_metadata_count);
- for (int32_t i = 0; i < output_metadata_count; i++) {
- ScoredEntityMetadata md;
- md.score = annotate_get_output_metadata_score_at_index_func_(job, i);
- md.metadata = GetEntityMetadataFromOptimizationGuideEntityMetadata(
- annotate_get_output_metadata_at_index_func_(job, i));
- scored_md.emplace_back(md);
- }
- annotate_job_delete_func_(job);
- return scored_md;
- }
- DISABLE_CFI_DLSYM
- absl::optional<EntityMetadata>
- EntityAnnotatorNativeLibrary::GetEntityMetadataForEntityId(
- void* annotator,
- const std::string& entity_id) {
- DCHECK(IsValid());
- if (!IsValid()) {
- return absl::nullopt;
- }
- if (!annotator) {
- return absl::nullopt;
- }
- void* job =
- entity_metadata_job_create_func_(reinterpret_cast<void*>(annotator));
- const void* entity_metadata =
- run_entity_metadata_job_func_(job, entity_id.c_str());
- if (!entity_metadata) {
- return absl::nullopt;
- }
- EntityMetadata md =
- GetEntityMetadataFromOptimizationGuideEntityMetadata(entity_metadata);
- entity_metadata_job_delete_func_(job);
- return md;
- }
- DISABLE_CFI_DLSYM
- EntityMetadata EntityAnnotatorNativeLibrary::
- GetEntityMetadataFromOptimizationGuideEntityMetadata(
- const void* og_entity_metadata) {
- EntityMetadata entity_metadata;
- entity_metadata.entity_id =
- entity_metadata_get_entity_id_func_(og_entity_metadata);
- entity_metadata.human_readable_name =
- entity_metadata_get_human_readable_name_func_(og_entity_metadata);
- int32_t human_readable_categories_count =
- entity_metadata_get_human_readable_categories_count_func_(
- og_entity_metadata);
- for (int32_t i = 0; i < human_readable_categories_count; i++) {
- std::string category_name =
- entity_metadata_get_human_readable_category_name_at_index_func_(
- og_entity_metadata, i);
- float category_score =
- entity_metadata_get_human_readable_category_score_at_index_func_(
- og_entity_metadata, i);
- entity_metadata.human_readable_categories[category_name] = category_score;
- }
- int32_t human_readable_aliases_count =
- entity_metadata_get_human_readable_aliases_count_func_(
- og_entity_metadata);
- for (int32_t i = 0; i < human_readable_aliases_count; i++) {
- entity_metadata.human_readable_aliases.push_back(
- entity_metadata_get_human_readable_alias_at_index_func_(
- og_entity_metadata, i));
- }
- int32_t collections_count =
- entity_metadata_get_collections_count_func_(og_entity_metadata);
- for (int32_t i = 0; i < collections_count; i++) {
- std::string collection =
- entity_metadata_get_collection_at_index_func_(og_entity_metadata, i);
- entity_metadata.collections.push_back(collection);
- }
- return entity_metadata;
- }
- } // namespace optimization_guide
|