quantized_nn_classifier_unittest.cc 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. // Copyright (c) 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/quantized_nn_classifier.h"
  5. #include "components/assist_ranker/nn_classifier.h"
  6. #include "components/assist_ranker/nn_classifier_test_util.h"
  7. #include "testing/gtest/include/gtest/gtest.h"
  8. namespace assist_ranker {
  9. namespace quantized_nn_classifier {
  10. namespace {
  11. using ::google::protobuf::RepeatedFieldBackInserter;
  12. using ::google::protobuf::RepeatedPtrField;
  13. using ::std::copy;
  14. using ::std::vector;
  15. void CreateLayer(const vector<int>& biases,
  16. const vector<vector<int>>& weights,
  17. float low,
  18. float high,
  19. QuantizedNNLayer* layer) {
  20. layer->set_biases(std::string(biases.begin(), biases.end()));
  21. for (const auto& i : weights) {
  22. layer->mutable_weights()->Add(std::string(i.begin(), i.end()));
  23. }
  24. layer->set_low(low);
  25. layer->set_high(high);
  26. }
  27. // Creates a QuantizedDNNClassifierModel proto using a trained set of biases and
  28. // weights.
  29. QuantizedNNClassifierModel CreateModel(
  30. const vector<int>& hidden_biases,
  31. const vector<vector<int>>& hidden_weights,
  32. const vector<int>& logits_biases,
  33. const vector<vector<int>>& logits_weights,
  34. float low,
  35. float high) {
  36. QuantizedNNClassifierModel model;
  37. CreateLayer(hidden_biases, hidden_weights, low, high,
  38. model.mutable_hidden_layer());
  39. CreateLayer(logits_biases, logits_weights, low, high,
  40. model.mutable_logits_layer());
  41. return model;
  42. }
  43. TEST(QuantizedNNClassifierTest, Dequantize) {
  44. const QuantizedNNClassifierModel quantized = CreateModel(
  45. // Hidden biases.
  46. {{8, 16, 32}},
  47. // Hidden weights.
  48. {{2, 4, 6}, {10, 4, 8}},
  49. // Logits biases.
  50. {2},
  51. // Logits weights.
  52. {{4}, {2}, {6}},
  53. // Low.
  54. 0,
  55. // High.
  56. 128);
  57. ASSERT_TRUE(Validate(quantized));
  58. const NNClassifierModel model = Dequantize(quantized);
  59. const NNClassifierModel expected = nn_classifier::CreateModel(
  60. // Hidden biases.
  61. {{4, 8, 16}},
  62. // Hidden weights.
  63. {{1, 2, 3}, {5, 2, 4}},
  64. // Logits biases.
  65. {1},
  66. // Logits weights.
  67. {{2}, {1}, {3}});
  68. EXPECT_EQ(model.SerializeAsString(), expected.SerializeAsString());
  69. }
  70. TEST(QuantizedNNClassifierTest, XorTest) {
  71. // Creates a NN with a single hidden layer of 5 units that solves XOR.
  72. // Creates a QuantizedDNNClassifier model containing the trained biases and
  73. // weights.
  74. const QuantizedNNClassifierModel quantized = CreateModel(
  75. // Hidden biases.
  76. {{110, 139, 175, 55, 106}},
  77. // Hidden weights.
  78. {{228, 127, 97, 217, 158}, {55, 219, 80, 199, 152}},
  79. // Logits biases.
  80. {74},
  81. // Logits weights.
  82. {{255}, {211}, {53}, {0}, {86}},
  83. // Low.
  84. -2.96390629,
  85. // High.
  86. 2.8636384);
  87. ASSERT_TRUE(Validate(quantized));
  88. const NNClassifierModel model = Dequantize(quantized);
  89. ASSERT_TRUE(nn_classifier::Validate(model));
  90. EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 0}, {-2.7032}));
  91. EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 1}, {2.80681}));
  92. EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 0}, {2.64435}));
  93. EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 1}, {-3.17825}));
  94. }
  95. TEST(QuantizedNNClassifierTest, ValidateQuantizedNNClassifierModel) {
  96. // Empty model.
  97. QuantizedNNClassifierModel model;
  98. EXPECT_FALSE(Validate(model));
  99. // Valid model.
  100. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
  101. 0, 1);
  102. EXPECT_TRUE(Validate(model));
  103. // Hidden bias incorrect size.
  104. model =
  105. CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
  106. EXPECT_FALSE(Validate(model));
  107. // Hidden weight vector incorrect size.
  108. model =
  109. CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
  110. EXPECT_FALSE(Validate(model));
  111. // Logits weights incorrect size.
  112. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}}, 0, 1);
  113. EXPECT_FALSE(Validate(model));
  114. // Empty logits bias.
  115. model =
  116. CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}}, 0, 1);
  117. EXPECT_FALSE(Validate(model));
  118. // Low / high incorrect.
  119. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
  120. 1, 0);
  121. EXPECT_FALSE(Validate(model));
  122. }
  123. } // namespace
  124. } // namespace quantized_nn_classifier
  125. } // namespace assist_ranker