123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- // Copyright (c) 2018 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/assist_ranker/quantized_nn_classifier.h"
- #include "components/assist_ranker/nn_classifier.h"
- #include "components/assist_ranker/nn_classifier_test_util.h"
- #include "testing/gtest/include/gtest/gtest.h"
- namespace assist_ranker {
- namespace quantized_nn_classifier {
- namespace {
- using ::google::protobuf::RepeatedFieldBackInserter;
- using ::google::protobuf::RepeatedPtrField;
- using ::std::copy;
- using ::std::vector;
- void CreateLayer(const vector<int>& biases,
- const vector<vector<int>>& weights,
- float low,
- float high,
- QuantizedNNLayer* layer) {
- layer->set_biases(std::string(biases.begin(), biases.end()));
- for (const auto& i : weights) {
- layer->mutable_weights()->Add(std::string(i.begin(), i.end()));
- }
- layer->set_low(low);
- layer->set_high(high);
- }
- // Creates a QuantizedDNNClassifierModel proto using a trained set of biases and
- // weights.
- QuantizedNNClassifierModel CreateModel(
- const vector<int>& hidden_biases,
- const vector<vector<int>>& hidden_weights,
- const vector<int>& logits_biases,
- const vector<vector<int>>& logits_weights,
- float low,
- float high) {
- QuantizedNNClassifierModel model;
- CreateLayer(hidden_biases, hidden_weights, low, high,
- model.mutable_hidden_layer());
- CreateLayer(logits_biases, logits_weights, low, high,
- model.mutable_logits_layer());
- return model;
- }
- TEST(QuantizedNNClassifierTest, Dequantize) {
- const QuantizedNNClassifierModel quantized = CreateModel(
- // Hidden biases.
- {{8, 16, 32}},
- // Hidden weights.
- {{2, 4, 6}, {10, 4, 8}},
- // Logits biases.
- {2},
- // Logits weights.
- {{4}, {2}, {6}},
- // Low.
- 0,
- // High.
- 128);
- ASSERT_TRUE(Validate(quantized));
- const NNClassifierModel model = Dequantize(quantized);
- const NNClassifierModel expected = nn_classifier::CreateModel(
- // Hidden biases.
- {{4, 8, 16}},
- // Hidden weights.
- {{1, 2, 3}, {5, 2, 4}},
- // Logits biases.
- {1},
- // Logits weights.
- {{2}, {1}, {3}});
- EXPECT_EQ(model.SerializeAsString(), expected.SerializeAsString());
- }
- TEST(QuantizedNNClassifierTest, XorTest) {
- // Creates a NN with a single hidden layer of 5 units that solves XOR.
- // Creates a QuantizedDNNClassifier model containing the trained biases and
- // weights.
- const QuantizedNNClassifierModel quantized = CreateModel(
- // Hidden biases.
- {{110, 139, 175, 55, 106}},
- // Hidden weights.
- {{228, 127, 97, 217, 158}, {55, 219, 80, 199, 152}},
- // Logits biases.
- {74},
- // Logits weights.
- {{255}, {211}, {53}, {0}, {86}},
- // Low.
- -2.96390629,
- // High.
- 2.8636384);
- ASSERT_TRUE(Validate(quantized));
- const NNClassifierModel model = Dequantize(quantized);
- ASSERT_TRUE(nn_classifier::Validate(model));
- EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 0}, {-2.7032}));
- EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 1}, {2.80681}));
- EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 0}, {2.64435}));
- EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 1}, {-3.17825}));
- }
- TEST(QuantizedNNClassifierTest, ValidateQuantizedNNClassifierModel) {
- // Empty model.
- QuantizedNNClassifierModel model;
- EXPECT_FALSE(Validate(model));
- // Valid model.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
- 0, 1);
- EXPECT_TRUE(Validate(model));
- // Hidden bias incorrect size.
- model =
- CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
- EXPECT_FALSE(Validate(model));
- // Hidden weight vector incorrect size.
- model =
- CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
- EXPECT_FALSE(Validate(model));
- // Logits weights incorrect size.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}}, 0, 1);
- EXPECT_FALSE(Validate(model));
- // Empty logits bias.
- model =
- CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}}, 0, 1);
- EXPECT_FALSE(Validate(model));
- // Low / high incorrect.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
- 1, 0);
- EXPECT_FALSE(Validate(model));
- }
- } // namespace
- } // namespace quantized_nn_classifier
- } // namespace assist_ranker
|