// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

  4. #include "components/assist_ranker/nn_classifier.h"
  5. #include "components/assist_ranker/nn_classifier_test_util.h"
  6. #include "testing/gtest/include/gtest/gtest.h"
  7. namespace assist_ranker {
  8. namespace nn_classifier {
  9. namespace {
  10. using ::google::protobuf::RepeatedFieldBackInserter;
  11. using ::std::copy;
  12. using ::std::vector;
  13. TEST(NNClassifierTest, XorTest) {
  14. // Creates a NN with a single hidden layer of 5 units that solves XOR.
  15. // Creates a DNNClassifier model containing the trained biases and weights.
  16. const NNClassifierModel model = CreateXorClassifierModel();
  17. ASSERT_TRUE(Validate(model));
  18. EXPECT_TRUE(CheckInference(model, {0, 0}, {-2.7154054}));
  19. EXPECT_TRUE(CheckInference(model, {0, 1}, {2.8271765}));
  20. EXPECT_TRUE(CheckInference(model, {1, 0}, {2.6790769}));
  21. EXPECT_TRUE(CheckInference(model, {1, 1}, {-3.1652793}));
  22. }
  23. TEST(NNClassifierTest, ValidateNNClassifierModel) {
  24. // Empty model.
  25. NNClassifierModel model;
  26. EXPECT_FALSE(Validate(model));
  27. // Valid model.
  28. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}});
  29. EXPECT_TRUE(Validate(model));
  30. // Too few hidden layer biases.
  31. model = CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}});
  32. EXPECT_FALSE(Validate(model));
  33. // Too few hidden layer weights.
  34. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}});
  35. EXPECT_FALSE(Validate(model));
  36. // Too few logits weights.
  37. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}});
  38. EXPECT_FALSE(Validate(model));
  39. // Logits biases empty.
  40. model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}});
  41. EXPECT_FALSE(Validate(model));
  42. }
  43. } // namespace
  44. } // namespace nn_classifier
  45. } // namespace assist_ranker