1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- // Copyright (c) 2018 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/assist_ranker/nn_classifier.h"
- #include "base/check_op.h"
- #include "components/assist_ranker/proto/nn_classifier.pb.h"
- namespace assist_ranker {
- namespace nn_classifier {
- namespace {
- using google::protobuf::RepeatedPtrField;
- using std::vector;
- vector<float> FeedForward(const NNLayer& layer, const vector<float>& input) {
- const RepeatedPtrField<FloatVector>& weights = layer.weights();
- const FloatVector& biases = layer.biases();
- // Number of nodes in the layer.
- const int num_nodes = biases.values().size();
- // Number of values in the input.
- const int num_input = input.size();
- DCHECK_EQ(weights.size(), num_input);
- // Initialize with the bias.
- vector<float> output(biases.values().begin(), biases.values().end());
- // For each value in the input.
- for (int j = 0; j < num_input; ++j) {
- const FloatVector& v = weights[j];
- DCHECK_EQ(v.values().size(), num_nodes);
- // For each node in the layer.
- for (int i = 0; i < num_nodes; ++i) {
- output[i] += v.values(i) * input[j];
- }
- }
- return output;
- }
- // Apply ReLU activation function to a vector, which sets all values to
- // max(0, value).
- void Relu(vector<float>* const v) {
- // We are modifying the vector so the iterator must be a reference.
- for (float& i : *v)
- if (i < 0.0f)
- i = 0.0f;
- }
- bool ValidateLayer(const NNLayer& layer) {
- // Number of nodes in the layer (must be non-zero).
- const int num_nodes = layer.biases().values().size();
- if (num_nodes == 0)
- return false;
- // Number of values in the input (must be non-zero).
- const int num_input = layer.weights().size();
- if (num_input == 0)
- return false;
- for (int j = 0; j < num_input; ++j) {
- // The size of each weight vector must be the number of nodes in the
- // layer.
- if (layer.weights(j).values().size() != num_nodes)
- return false;
- }
- return true;
- }
- } // namespace
- bool Validate(const NNClassifierModel& model) {
- // Check the size of the output from the hidden layer is equal to the size
- // of the input in the logits layer.
- if (model.hidden_layer().biases().values().size() !=
- model.logits_layer().weights().size()) {
- return false;
- }
- return ValidateLayer(model.hidden_layer()) &&
- ValidateLayer(model.logits_layer());
- }
- vector<float> Inference(const NNClassifierModel& model,
- const vector<float>& input) {
- vector<float> v = FeedForward(model.hidden_layer(), input);
- Relu(&v);
- // Feed forward the logits layer.
- return FeedForward(model.logits_layer(), v);
- }
- } // namespace nn_classifier
- } // namespace assist_ranker
|