12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- // Copyright (c) 2018 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/assist_ranker/nn_classifier.h"
- #include "components/assist_ranker/nn_classifier_test_util.h"
- #include "testing/gtest/include/gtest/gtest.h"
- namespace assist_ranker {
- namespace nn_classifier {
- namespace {
- using ::google::protobuf::RepeatedFieldBackInserter;
- using ::std::copy;
- using ::std::vector;
- TEST(NNClassifierTest, XorTest) {
- // Creates a NN with a single hidden layer of 5 units that solves XOR.
- // Creates a DNNClassifier model containing the trained biases and weights.
- const NNClassifierModel model = CreateXorClassifierModel();
- ASSERT_TRUE(Validate(model));
- EXPECT_TRUE(CheckInference(model, {0, 0}, {-2.7154054}));
- EXPECT_TRUE(CheckInference(model, {0, 1}, {2.8271765}));
- EXPECT_TRUE(CheckInference(model, {1, 0}, {2.6790769}));
- EXPECT_TRUE(CheckInference(model, {1, 1}, {-3.1652793}));
- }
- TEST(NNClassifierTest, ValidateNNClassifierModel) {
- // Empty model.
- NNClassifierModel model;
- EXPECT_FALSE(Validate(model));
- // Valid model.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}});
- EXPECT_TRUE(Validate(model));
- // Too few hidden layer biases.
- model = CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}});
- EXPECT_FALSE(Validate(model));
- // Too few hidden layer weights.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}});
- EXPECT_FALSE(Validate(model));
- // Too few logits weights.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}});
- EXPECT_FALSE(Validate(model));
- // Logits biases empty.
- model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}});
- EXPECT_FALSE(Validate(model));
- }
- } // namespace
- } // namespace nn_classifier
- } // namespace assist_ranker
|