nn_classifier.cc 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. // Copyright (c) 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/nn_classifier.h"
  5. #include "base/check_op.h"
  6. #include "components/assist_ranker/proto/nn_classifier.pb.h"
  7. namespace assist_ranker {
  8. namespace nn_classifier {
  9. namespace {
  10. using google::protobuf::RepeatedPtrField;
  11. using std::vector;
  12. vector<float> FeedForward(const NNLayer& layer, const vector<float>& input) {
  13. const RepeatedPtrField<FloatVector>& weights = layer.weights();
  14. const FloatVector& biases = layer.biases();
  15. // Number of nodes in the layer.
  16. const int num_nodes = biases.values().size();
  17. // Number of values in the input.
  18. const int num_input = input.size();
  19. DCHECK_EQ(weights.size(), num_input);
  20. // Initialize with the bias.
  21. vector<float> output(biases.values().begin(), biases.values().end());
  22. // For each value in the input.
  23. for (int j = 0; j < num_input; ++j) {
  24. const FloatVector& v = weights[j];
  25. DCHECK_EQ(v.values().size(), num_nodes);
  26. // For each node in the layer.
  27. for (int i = 0; i < num_nodes; ++i) {
  28. output[i] += v.values(i) * input[j];
  29. }
  30. }
  31. return output;
  32. }
  33. // Apply ReLU activation function to a vector, which sets all values to
  34. // max(0, value).
  35. void Relu(vector<float>* const v) {
  36. // We are modifying the vector so the iterator must be a reference.
  37. for (float& i : *v)
  38. if (i < 0.0f)
  39. i = 0.0f;
  40. }
  41. bool ValidateLayer(const NNLayer& layer) {
  42. // Number of nodes in the layer (must be non-zero).
  43. const int num_nodes = layer.biases().values().size();
  44. if (num_nodes == 0)
  45. return false;
  46. // Number of values in the input (must be non-zero).
  47. const int num_input = layer.weights().size();
  48. if (num_input == 0)
  49. return false;
  50. for (int j = 0; j < num_input; ++j) {
  51. // The size of each weight vector must be the number of nodes in the
  52. // layer.
  53. if (layer.weights(j).values().size() != num_nodes)
  54. return false;
  55. }
  56. return true;
  57. }
  58. } // namespace
  59. bool Validate(const NNClassifierModel& model) {
  60. // Check the size of the output from the hidden layer is equal to the size
  61. // of the input in the logits layer.
  62. if (model.hidden_layer().biases().values().size() !=
  63. model.logits_layer().weights().size()) {
  64. return false;
  65. }
  66. return ValidateLayer(model.hidden_layer()) &&
  67. ValidateLayer(model.logits_layer());
  68. }
  69. vector<float> Inference(const NNClassifierModel& model,
  70. const vector<float>& input) {
  71. vector<float> v = FeedForward(model.hidden_layer(), input);
  72. Relu(&v);
  73. // Feed forward the logits layer.
  74. return FeedForward(model.logits_layer(), v);
  75. }
  76. } // namespace nn_classifier
  77. } // namespace assist_ranker