csum.c

// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

/* Looks dumb, but generates nice-ish code */
static u64 accumulate(u64 sum, u64 data)
{
	__uint128_t tmp = (__uint128_t)sum + data;

	return tmp + (tmp >> 64);
}
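
/*
 * accumulate() is a 64-bit end-around-carry (ones' complement) add: any
 * carry out of bit 63 wraps back into bit 0. Worked example:
 *
 *	accumulate(0xffffffffffffffff, 0x2)
 *		tmp       = 0x1_0000000000000001  (bit 64 is the carry)
 *		tmp >> 64 = 0x1
 *		returns     0x0000000000000002
 *
 * which is exactly the wrap-around result the Internet checksum needs.
 */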

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	if (unlikely(len == 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;
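	/*
	 * len is rebased to measure how far the buffer extends past the first
	 * (rounded-down) doubleword. It can drop to zero or below here, in
	 * which case both loops and the if below are skipped and the tail
	 * masking deals with everything.
	 */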

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
	shift = offset * 8;
	data = *ptr++;
#ifdef __LITTLE_ENDIAN
	data = (data >> shift) << shift;
#else
	data = (data << shift) >> shift;
#endif
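	/*
	 * Example (little-endian, offset == 3): shift == 24, so the three
	 * bytes preceding @buff in the rounded-down doubleword are shifted out
	 * and back in as zeroes, leaving only bytes that belong to the buffer.
	 */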

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
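		/*
		 * Adding each 128-bit value to a copy with its halves swapped
		 * leaves, in the upper 64 bits, the sum of the two halves plus
		 * the carry out of the lower addition - a 64-bit end-around-
		 * carry sum computed without ever moving the carry flag into a
		 * general-purpose register.
		 */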
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
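	/*
	 * sum64 now folds in everything consumed by the unrolled loop; the
	 * head doubleword is still pending in @data and gets accumulated by
	 * the code below.
	 */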
	while (len > 8) {
		__uint128_t tmp;

		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

#ifdef __LITTLE_ENDIAN
		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
#else
		data = tmp;
		sum64 = accumulate(sum64, tmp >> 64);
#endif
	}
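	/*
	 * @data is always accumulated one step late, so that the final
	 * doubleword is still in hand when we reach the tail masking below.
	 */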
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
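	/*
	 * len is now in the range (-8, 0]: -len counts the bytes that were
	 * over-read past the end of the buffer, and the shift pair below
	 * clears exactly those bytes (the high-order ones on little-endian).
	 */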
	shift = len * -8;
#ifdef __LITTLE_ENDIAN
	data = (data << shift) >> shift;
#else
	data = (data >> shift) << shift;
#endif
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
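	/*
	 * The folded result lives in the top 16 bits of sum. If the buffer
	 * started at an odd address, every byte was summed in the opposite
	 * lane of its 16-bit word, so return it byte-swapped (which is what
	 * the swab32()/u16 cast amounts to).
	 */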
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}
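
/*
 * do_csum() returns the buffer's 16-bit ones' complement sum, not yet
 * inverted. Sketch of the usual consumer (assumption: the generic
 * csum_partial() in lib/checksum.c, which folds a caller-supplied partial
 * checksum on top with an end-around carry):
 *
 *	__wsum csum_partial(const void *buff, int len, __wsum wsum)
 *	{
 *		unsigned int sum = (__force unsigned int)wsum;
 *		unsigned int result = do_csum(buff, len);
 *
 *		result += sum;
 *		if (sum > result)
 *			result += 1;
 *		return (__force __wsum)result;
 *	}
 */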

__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

	sum += (__force u32)htonl(len);
#ifdef __LITTLE_ENDIAN
	sum += (u32)proto << 24;
#else
	sum += proto;
#endif
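	/*
	 * The additions above fold in the remaining IPv6 pseudo-header fields:
	 * the upper-layer length and the next-header value, both in network
	 * byte order (the #ifdef is just an open-coded htonl(proto)).
	 */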
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

	sum += ((sum >> 32) | (sum << 32));
	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);