ranker_model_unittest.cc 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/ranker_model.h"
  5. #include <memory>
  6. #include "base/time/time.h"
  7. #include "components/assist_ranker/proto/ranker_model.pb.h"
  8. #include "testing/gtest/include/gtest/gtest.h"
  9. namespace {
  10. using assist_ranker::RankerModel;
  11. const char kModelURL[] = "https://some.url.net/model";
  12. int64_t InSeconds(const base::Time t) {
  13. return (t - base::Time()).InSeconds();
  14. }
  15. std::unique_ptr<RankerModel> NewModel(const std::string& model_url,
  16. base::Time last_modified,
  17. base::TimeDelta cache_duration) {
  18. std::unique_ptr<RankerModel> model = std::make_unique<RankerModel>();
  19. auto* metadata = model->mutable_proto()->mutable_metadata();
  20. if (!model_url.empty())
  21. metadata->set_source(model_url);
  22. if (!last_modified.is_null())
  23. metadata->set_last_modified_sec(InSeconds(last_modified));
  24. if (!cache_duration.is_zero())
  25. metadata->set_cache_duration_sec(cache_duration.InSeconds());
  26. auto* translate = model->mutable_proto()->mutable_translate();
  27. translate->set_version(1);
  28. auto* logit = translate->mutable_translate_logistic_regression_model();
  29. logit->set_bias(0.1f);
  30. logit->set_accept_ratio_weight(0.2f);
  31. logit->set_decline_ratio_weight(0.3f);
  32. logit->set_ignore_ratio_weight(0.4f);
  33. return model;
  34. }
  35. } // namespace
  36. TEST(RankerModelTest, Serialization) {
  37. base::Time last_modified = base::Time::Now();
  38. base::TimeDelta cache_duration = base::Days(3);
  39. std::unique_ptr<RankerModel> original_model =
  40. NewModel(kModelURL, last_modified, cache_duration);
  41. std::string original_model_str = original_model->SerializeAsString();
  42. std::unique_ptr<RankerModel> serialized_model =
  43. RankerModel::FromString(original_model_str);
  44. std::string serialized_model_str = serialized_model->SerializeAsString();
  45. EXPECT_EQ(serialized_model_str, original_model_str);
  46. EXPECT_EQ(serialized_model->GetSourceURL(), kModelURL);
  47. EXPECT_EQ(serialized_model->proto().metadata().last_modified_sec(),
  48. InSeconds(last_modified));
  49. EXPECT_EQ(serialized_model->proto().metadata().cache_duration_sec(),
  50. cache_duration.InSeconds());
  51. }
  52. TEST(RankerModelTest, IsExpired) {
  53. base::Time today = base::Time::Now();
  54. base::TimeDelta days_15 = base::Days(15);
  55. base::TimeDelta days_30 = base::Days(30);
  56. base::TimeDelta days_60 = base::Days(60);
  57. EXPECT_FALSE(NewModel(kModelURL, today, days_30)->IsExpired());
  58. EXPECT_FALSE(NewModel(kModelURL, today - days_15, days_30)->IsExpired());
  59. EXPECT_TRUE(NewModel(kModelURL, today - days_60, days_30)->IsExpired());
  60. }