assist_ranker_service_impl.cc 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/assist_ranker_service_impl.h"
  5. #include "base/memory/weak_ptr.h"
  6. #include "components/assist_ranker/binary_classifier_predictor.h"
  7. #include "components/assist_ranker/ranker_model_loader_impl.h"
  8. #include "services/network/public/cpp/shared_url_loader_factory.h"
  9. #include "url/gurl.h"
  10. namespace assist_ranker {
  11. AssistRankerServiceImpl::AssistRankerServiceImpl(
  12. base::FilePath base_path,
  13. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
  14. : url_loader_factory_(std::move(url_loader_factory)),
  15. base_path_(std::move(base_path)) {}
  16. AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
  17. base::WeakPtr<BinaryClassifierPredictor>
  18. AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
  19. const PredictorConfig& config) {
  20. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  21. const std::string& model_name = config.model_name;
  22. auto predictor_it = predictor_map_.find(model_name);
  23. if (predictor_it != predictor_map_.end()) {
  24. DVLOG(1) << "Predictor " << model_name << " already initialized.";
  25. return base::AsWeakPtr(
  26. static_cast<BinaryClassifierPredictor*>(predictor_it->second.get()));
  27. }
  28. // The predictor does not exist yet, so we create one.
  29. DVLOG(1) << "Initializing predictor: " << model_name;
  30. std::unique_ptr<BinaryClassifierPredictor> predictor =
  31. BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
  32. url_loader_factory_);
  33. base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
  34. base::AsWeakPtr(predictor.get());
  35. predictor_map_[model_name] = std::move(predictor);
  36. return weak_ptr;
  37. }
  38. base::FilePath AssistRankerServiceImpl::GetModelPath(
  39. const std::string& model_filename) {
  40. if (base_path_.empty())
  41. return base::FilePath();
  42. return base_path_.AppendASCII(model_filename);
  43. }
  44. } // namespace assist_ranker