Add msan helper functions to common_neon.h

The implementation differs from x86 because the VEXT instruction is
a reversal of _mm_srli_si128. Instead of pulling zeroes from the back,
the mask pushes 0xFF bytes in from the front.

Usage in film_grain_neon.cc prevents use-of-uninitialized-value errors.

PiperOrigin-RevId: 386899807
Change-Id: I93fc6b246a8555521eb3c174ecc05961f15347c2
diff --git a/src/dsp/arm/common_neon.h b/src/dsp/arm/common_neon.h
index 53a038f..c87b336 100644
--- a/src/dsp/arm/common_neon.h
+++ b/src/dsp/arm/common_neon.h
@@ -23,10 +23,13 @@
 
 #include <arm_neon.h>
 
+#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <cstring>
 
+#include "src/utils/compiler_attributes.h"
+
 #if 0
 #include <cstdio>
 #include <string>
@@ -248,6 +251,81 @@
 }
 
 //------------------------------------------------------------------------------
+// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.
+
+// Clears the last |over_read_in_bytes| bytes of |source| when LIBGAV1_MSAN is
+// nonzero; otherwise (or when nothing was over-read) returns it unchanged.
+inline uint8x8_t MaskOverreads(const uint8x8_t source,
+                               const ptrdiff_t over_read_in_bytes) {
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes <= 0) return source;
+  const int valid_bytes =
+      std::min(8, 8 - static_cast<int>(over_read_in_bytes));
+  const uint8x8_t all_ones = vdup_n_u8(-1);
+  uint8x8_t mask = vdup_n_u8(0);
+  // Each VEXT shifts one 0xff byte into lane 0 of |mask|.
+  for (int i = 0; i < valid_bytes; ++i) mask = vext_u8(all_ones, mask, 7);
+  return vand_u8(source, mask);
+#else
+  static_cast<void>(over_read_in_bytes);
+  return source;
+#endif
+}
+
+// 128-bit counterpart of MaskOverreads(): clears the last
+// |over_read_in_bytes| bytes of |source| when LIBGAV1_MSAN is nonzero.
+inline uint8x16_t MaskOverreadsQ(const uint8x16_t source,
+                                 const ptrdiff_t over_read_in_bytes) {
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes <= 0) return source;
+  const int valid_bytes =
+      std::min(16, 16 - static_cast<int>(over_read_in_bytes));
+  const uint8x16_t all_ones = vdupq_n_u8(-1);
+  uint8x16_t mask = vdupq_n_u8(0);
+  // Each VEXT shifts one 0xff byte into lane 0 of |mask|.
+  for (int i = 0; i < valid_bytes; ++i) mask = vextq_u8(all_ones, mask, 15);
+  return vandq_u8(source, mask);
+#else
+  static_cast<void>(over_read_in_bytes);
+  return source;
+#endif
+}
+
+// vld1_u8() whose result is passed through MaskOverreads().
+inline uint8x8_t Load1MsanU8(const uint8_t* const source,
+                             const ptrdiff_t over_read_in_bytes) {
+  return MaskOverreads(vld1_u8(source), over_read_in_bytes);
+}
+
+// vld1q_u8() whose result is passed through MaskOverreadsQ().
+inline uint8x16_t Load1QMsanU8(const uint8_t* const source,
+                               const ptrdiff_t over_read_in_bytes) {
+  return MaskOverreadsQ(vld1q_u8(source), over_read_in_bytes);
+}
+
+// vld1q_u16() whose result is passed, as bytes, through MaskOverreadsQ().
+inline uint16x8_t Load1QMsanU16(const uint16_t* const source,
+                                const ptrdiff_t over_read_in_bytes) {
+  return vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(vld1q_u16(source)), over_read_in_bytes));
+}
+
+// Deinterleaving load of 16 uint16_t values; |over_read_in_bytes| is relative
+// to the 32 bytes read. dst.val[0] holds source elements 00 02 04 .. 14 and
+// dst.val[1] holds 01 03 05 .. 15, so of n over-read elements val[1] contains
+// ceil(n/2) and val[0] floor(n/2). Only whole elements are masked.
+inline uint16x8x2_t Load2QMsanU16(const uint16_t* const source,
+                                  const ptrdiff_t over_read_in_bytes) {
+  uint16x8x2_t dst = vld2q_u16(source);
+  const ptrdiff_t over_read_elements = (over_read_in_bytes + 1) / 2;
+  dst.val[0] = vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(dst.val[0]), (over_read_elements / 2) * 2));
+  dst.val[1] = vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(dst.val[1]), ((over_read_elements + 1) / 2) * 2));
+  return dst;
+}
+
+//------------------------------------------------------------------------------
 // Store functions.
 
 // Propagate type information to the compiler. Without this the compiler may
diff --git a/src/dsp/arm/film_grain_neon.cc b/src/dsp/arm/film_grain_neon.cc
index c350970..6ce38a0 100644
--- a/src/dsp/arm/film_grain_neon.cc
+++ b/src/dsp/arm/film_grain_neon.cc
@@ -52,6 +52,10 @@
   return ZeroExtend(vld1_u8(src));
 }
 
+inline int16x8_t GetSignedSource8Msan(const uint8_t* src, int valid_range) {
+  return ZeroExtend(Load1MsanU8(src, 8 - valid_range));
+}
+
 inline void StoreUnsigned8(uint8_t* dest, const uint16x8_t data) {
   vst1_u8(dest, vmovn_u16(data));
 }
@@ -63,6 +67,10 @@
   return vreinterpretq_s16_u16(vld1q_u16(src));
 }
 
+inline int16x8_t GetSignedSource8Msan(const uint16_t* src, int valid_range) {
+  return vreinterpretq_s16_u16(Load1QMsanU16(src, 16 - valid_range));
+}
+
 inline void StoreUnsigned8(uint16_t* dest, const uint16x8_t data) {
   vst1q_u16(dest, data);
 }
@@ -184,6 +192,16 @@
   return vmovl_u8(vld1_u8(luma));
 }
 
+// Variant of GetAverageLuma() that reads |luma| through the Msan load helpers,
+// forwarding the part of each load that lies beyond |valid_range| bytes.
+inline uint16x8_t GetAverageLumaMsan(const uint8_t* const luma,
+                                     int subsampling_x, int valid_range) {
+  if (subsampling_x == 0) return vmovl_u8(Load1MsanU8(luma, 8 - valid_range));
+  // Pairwise add of the 16 samples, then a rounding shift right by one.
+  const uint8x16_t pairs = Load1QMsanU8(luma, 16 - valid_range);
+  return vrshrq_n_u16(vpaddlq_u8(pairs), 1);
+}
+
 #if LIBGAV1_MAX_BITDEPTH >= 10
 // Computes subsampled luma for use with chroma, by averaging in the x direction
 // or y direction when applicable.
@@ -223,6 +241,15 @@
   }
   return vld1q_u16(luma);
 }
+
+// uint16_t overload. |valid_range| is in bytes; the remainder of each 16- or
+// 32-byte load is handed to the Msan load helpers as the over-read size.
+inline uint16x8_t GetAverageLumaMsan(const uint16_t* const luma,
+                                     int subsampling_x, int valid_range) {
+  if (subsampling_x == 0) return Load1QMsanU16(luma, 16 - valid_range);
+  const uint16x8x2_t pairs = Load2QMsanU16(luma, 32 - valid_range);
+  return vrhaddq_u16(pairs.val[0], pairs.val[1]);
+}
 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
 
 template <int bitdepth, typename GrainType, int auto_regression_coeff_lag,
@@ -817,11 +844,13 @@
 
     if (x < chroma_width) {
       const int luma_x = x << subsampling_x;
-      const int valid_range = width - luma_x;
-      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
-      luma_buffer[valid_range] = in_y_row[width - 1];
-      const uint16x8_t average_luma =
-          GetAverageLuma(luma_buffer, subsampling_x);
+      const int valid_range_pixels = width - luma_x;
+      const int valid_range_bytes = valid_range_pixels * sizeof(in_y_row[0]);
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range_bytes);
+      luma_buffer[valid_range_pixels] = in_y_row[width - 1];
+      const uint16x8_t average_luma = GetAverageLumaMsan(
+          luma_buffer, subsampling_x, valid_range_bytes + sizeof(in_y_row[0]));
+
       StoreUnsigned8(average_luma_buffer, average_luma);
 
       const int16x8_t blended =
@@ -875,13 +904,11 @@
 namespace {
 
 inline int16x8_t BlendChromaValsNoCfl(
-    const int16_t* LIBGAV1_RESTRICT scaling_lut,
-    const uint8_t* LIBGAV1_RESTRICT chroma_cursor,
+    const int16_t* LIBGAV1_RESTRICT scaling_lut, const int16x8_t orig,
     const int8_t* LIBGAV1_RESTRICT noise_image_cursor,
     const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect,
     const int16x8_t& offset, int luma_multiplier, int chroma_multiplier) {
   uint8_t merged_buffer[8];
-  const int16x8_t orig = GetSignedSource8(chroma_cursor);
   const int16x8_t weighted_luma = vmulq_n_s16(average_luma, luma_multiplier);
   const int16x8_t weighted_chroma = vmulq_n_s16(orig, chroma_multiplier);
   // Maximum value of |combined_u| is 127*255 = 0x7E81.
@@ -926,10 +953,13 @@
     int x = 0;
     do {
       const int luma_x = x << subsampling_x;
+      const int valid_range = width - luma_x;
+
+      const int16x8_t orig_chroma = GetSignedSource8(&in_chroma_row[x]);
       const int16x8_t average_luma = vreinterpretq_s16_u16(
-          GetAverageLuma(&in_y_row[luma_x], subsampling_x));
+          GetAverageLumaMsan(&in_y_row[luma_x], subsampling_x, valid_range));
       const int16x8_t blended = BlendChromaValsNoCfl(
-          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
           average_luma, scaling_shift_vect, offset, luma_multiplier,
           chroma_multiplier);
       // In 8bpp, when params_.clip_to_restricted_range == false, we can
@@ -945,14 +975,21 @@
       // |average_luma| computation requires a duplicated luma value at the
       // end.
       const int luma_x = x << subsampling_x;
-      const int valid_range = width - luma_x;
-      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
-      luma_buffer[valid_range] = in_y_row[width - 1];
+      const int valid_range_pixels = width - luma_x;
+      const int valid_range_bytes = valid_range_pixels * sizeof(in_y_row[0]);
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range_bytes);
+      luma_buffer[valid_range_pixels] = in_y_row[width - 1];
+      const int valid_range_chroma_bytes =
+          (chroma_width - x) * sizeof(in_chroma_row[0]);
+      uint8_t chroma_buffer[8];
+      memcpy(chroma_buffer, &in_chroma_row[x], valid_range_chroma_bytes);
 
-      const int16x8_t average_luma =
-          vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x));
+      const int16x8_t orig_chroma =
+          GetSignedSource8Msan(chroma_buffer, valid_range_chroma_bytes);
+      const int16x8_t average_luma = vreinterpretq_s16_u16(GetAverageLumaMsan(
+          luma_buffer, subsampling_x, valid_range_bytes + sizeof(in_y_row[0])));
       const int16x8_t blended = BlendChromaValsNoCfl(
-          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
           average_luma, scaling_shift_vect, offset, luma_multiplier,
           chroma_multiplier);
       StoreUnsigned8(&out_chroma_row[x],
@@ -1218,13 +1255,11 @@
 }
 
 inline int16x8_t BlendChromaValsNoCfl(
-    const int16_t* LIBGAV1_RESTRICT scaling_lut,
-    const uint16_t* LIBGAV1_RESTRICT chroma_cursor,
+    const int16_t* LIBGAV1_RESTRICT scaling_lut, const int16x8_t orig,
     const int16_t* LIBGAV1_RESTRICT noise_image_cursor,
     const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect,
     const int32x4_t& offset, int luma_multiplier, int chroma_multiplier) {
   uint16_t merged_buffer[8];
-  const int16x8_t orig = GetSignedSource8(chroma_cursor);
   const int32x4_t weighted_luma_low =
       vmull_n_s16(vget_low_s16(average_luma), luma_multiplier);
   const int32x4_t weighted_luma_high =
@@ -1280,8 +1315,9 @@
       const int luma_x = x << subsampling_x;
       const int16x8_t average_luma = vreinterpretq_s16_u16(
           GetAverageLuma(&in_y_row[luma_x], subsampling_x));
+      const int16x8_t orig_chroma = GetSignedSource8(&in_chroma_row[x]);
       const int16x8_t blended = BlendChromaValsNoCfl(
-          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
           average_luma, scaling_shift_vect, offset, luma_multiplier,
           chroma_multiplier);
       StoreUnsigned8(&out_chroma_row[x],
@@ -1295,14 +1331,21 @@
       // |average_luma| computation requires a duplicated luma value at the
       // end.
       const int luma_x = x << subsampling_x;
-      const int valid_range = width - luma_x;
-      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
-      luma_buffer[valid_range] = in_y_row[width - 1];
+      const int valid_range_pixels = width - luma_x;
+      const int valid_range_bytes = valid_range_pixels * sizeof(in_y_row[0]);
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range_bytes);
+      luma_buffer[valid_range_pixels] = in_y_row[width - 1];
+      const int valid_range_chroma_bytes =
+          (chroma_width - x) * sizeof(in_chroma_row[0]);
+      uint16_t chroma_buffer[8];
+      memcpy(chroma_buffer, &in_chroma_row[x], valid_range_chroma_bytes);
+      const int16x8_t orig_chroma =
+          GetSignedSource8Msan(chroma_buffer, valid_range_chroma_bytes);
 
-      const int16x8_t average_luma =
-          vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x));
+      const int16x8_t average_luma = vreinterpretq_s16_u16(GetAverageLumaMsan(
+          luma_buffer, subsampling_x, valid_range_bytes + sizeof(in_y_row[0])));
       const int16x8_t blended = BlendChromaValsNoCfl(
-          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
           average_luma, scaling_shift_vect, offset, luma_multiplier,
           chroma_multiplier);
       StoreUnsigned8(&out_chroma_row[x],