learning_fuzzertest.cc 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include <fuzzer/FuzzedDataProvider.h>
  5. #include "base/test/task_environment.h"
  6. #include "media/learning/impl/learning_task_controller_impl.h"
  7. using media::learning::FeatureValue;
  8. using media::learning::FeatureVector;
  9. using media::learning::LearningTask;
  10. using ValueDescription = media::learning::LearningTask::ValueDescription;
  11. using media::learning::LearningTaskControllerImpl;
  12. using media::learning::ObservationCompletion;
  13. using media::learning::TargetValue;
  14. ValueDescription ConsumeValueDescription(FuzzedDataProvider* provider) {
  15. ValueDescription desc;
  16. desc.name = provider->ConsumeRandomLengthString(100);
  17. desc.ordering = provider->ConsumeEnum<LearningTask::Ordering>();
  18. desc.privacy_mode = provider->ConsumeEnum<LearningTask::PrivacyMode>();
  19. return desc;
  20. }
  21. double ConsumeDouble(FuzzedDataProvider* provider) {
  22. std::vector<uint8_t> v = provider->ConsumeBytes<uint8_t>(sizeof(double));
  23. if (v.size() == sizeof(double))
  24. return reinterpret_cast<double*>(v.data())[0];
  25. return 0;
  26. }
  27. FeatureVector ConsumeFeatureVector(FuzzedDataProvider* provider) {
  28. FeatureVector features;
  29. int n = provider->ConsumeIntegralInRange(0, 100);
  30. while (n-- > 0)
  31. features.push_back(FeatureValue(ConsumeDouble(provider)));
  32. return features;
  33. }
  34. extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
  35. base::test::TaskEnvironment task_environment;
  36. FuzzedDataProvider provider(data, size);
  37. LearningTask task;
  38. task.name = provider.ConsumeRandomLengthString(100);
  39. task.model = provider.ConsumeEnum<LearningTask::Model>();
  40. task.use_one_hot_conversion = provider.ConsumeBool();
  41. task.uma_hacky_aggregate_confusion_matrix = provider.ConsumeBool();
  42. task.uma_hacky_by_training_weight_confusion_matrix = provider.ConsumeBool();
  43. task.uma_hacky_by_feature_subset_confusion_matrix = provider.ConsumeBool();
  44. int n_features = provider.ConsumeIntegralInRange(0, 100);
  45. int subset_size = provider.ConsumeIntegralInRange<uint8_t>(0, n_features);
  46. if (subset_size)
  47. task.feature_subset_size = subset_size;
  48. for (int i = 0; i < n_features; i++)
  49. task.feature_descriptions.push_back(ConsumeValueDescription(&provider));
  50. task.target_description = ConsumeValueDescription(&provider);
  51. LearningTaskControllerImpl controller(task);
  52. // Build random examples.
  53. while (provider.remaining_bytes() > 0) {
  54. base::UnguessableToken id = base::UnguessableToken::Create();
  55. absl::optional<TargetValue> default_target;
  56. if (provider.ConsumeBool())
  57. default_target = TargetValue(ConsumeDouble(&provider));
  58. controller.BeginObservation(id, ConsumeFeatureVector(&provider),
  59. default_target, absl::nullopt);
  60. controller.CompleteObservation(
  61. id, ObservationCompletion(TargetValue(ConsumeDouble(&provider)),
  62. ConsumeDouble(&provider)));
  63. task_environment.RunUntilIdle();
  64. }
  65. return 0;
  66. }