libgav1: update to cl/267700628

this improves decoding performance by ~30% over the last snapshot

Bug: 130249450
Bug: 140837331
Test: aosp_arm-eng aosp_arm64-eng aosp_x86-eng aosp_x86_64-eng build
Test: cts-tradefed run commandAndExit cts-dev --include-filter CtsMediaTestCases --module-arg "CtsMediaTestCases:include-filter:android.media.cts.DecoderTest#testAV1*" --module-arg "CtsMediaTestCases:include-filter:android.media.cts.AdaptivePlaybackTest#testAV1*" (bonito-userdebug)
Test: cts-tradefed run commandAndExit cts-dev -a {armeabi-v7a,arm64-v8a} --include-filter CtsMediaTestCases --module-arg "CtsMediaTestCases:include-filter:android.media.cts.DecoderTest#testAV1*" --module-arg "CtsMediaTestCases:include-filter:android.media.cts.AdaptivePlaybackTest#testAV1*" (bonito_hwasan-userdebug)
Test: (externally) libaom/Argon/Allegro test vectors, cpu/memory performance, fuzzing

Change-Id: I5747e8c6a67450dd6d0ad6800f26100a36df0e26
(cherry picked from commit 2bf918d9e5232a973c834ad56ca810277c467e7b)
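---
Note: a minimal sketch of how a tile worker might use the
DecoderScratchBufferPool introduced below. |scratch_buffer_pool| stands in
for the pool pointer handed to each Tile; error handling is abbreviated.

  std::unique_ptr<DecoderScratchBuffer> scratch = scratch_buffer_pool->Get();
  if (scratch == nullptr) return kLibgav1StatusOutOfMemory;  // Allocation failed.
  // ... decode one superblock using scratch->prediction_buffer etc. ...
  scratch_buffer_pool->Release(std::move(scratch));  // Recycle for the next superblock.
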
diff --git a/Android.bp b/Android.bp
index 30eca18..e438d44 100644
--- a/Android.bp
+++ b/Android.bp
@@ -40,6 +40,7 @@
         "libgav1/src/buffer_pool.cc",
         "libgav1/src/decoder.cc",
         "libgav1/src/decoder_impl.cc",
+        "libgav1/src/decoder_scratch_buffer.cc",
         "libgav1/src/dsp/arm/average_blend_neon.cc",
         "libgav1/src/dsp/arm/convolve_neon.cc",
         "libgav1/src/dsp/arm/distance_weighted_blend_neon.cc",
diff --git a/libgav1/src/decoder_impl.cc b/libgav1/src/decoder_impl.cc
index dbfaf4c..9448247 100644
--- a/libgav1/src/decoder_impl.cc
+++ b/libgav1/src/decoder_impl.cc
@@ -552,7 +552,8 @@
           prev_segment_ids, &post_filter, &block_parameters_holder, &cdef_index,
           &inter_transform_sizes_, dsp,
           threading_strategy_.row_thread_pool(tile_index++),
-          residual_buffer_pool_.get(), &pending_tiles));
+          residual_buffer_pool_.get(), &decoder_scratch_buffer_pool_,
+          &pending_tiles));
       if (tile == nullptr) {
         LIBGAV1_DLOG(ERROR, "Failed to allocate tile.");
         return kLibgav1StatusOutOfMemory;
diff --git a/libgav1/src/decoder_impl.h b/libgav1/src/decoder_impl.h
index c4ea526..4e4c6b5 100644
--- a/libgav1/src/decoder_impl.h
+++ b/libgav1/src/decoder_impl.h
@@ -138,6 +138,7 @@
   AlignedUniquePtr<uint8_t> threaded_window_buffer_;
   size_t threaded_window_buffer_size_ = 0;
   Array2D<TransformSize> inter_transform_sizes_;
+  DecoderScratchBufferPool decoder_scratch_buffer_pool_;
 
   LoopFilterMask loop_filter_mask_;
 
diff --git a/libgav1/src/decoder_scratch_buffer.cc b/libgav1/src/decoder_scratch_buffer.cc
new file mode 100644
index 0000000..ea897e2
--- /dev/null
+++ b/libgav1/src/decoder_scratch_buffer.cc
@@ -0,0 +1,9 @@
+#include "src/decoder_scratch_buffer.h"
+
+namespace libgav1 {
+
+// static
+constexpr int DecoderScratchBuffer::kBlockDecodedStride;
+constexpr int DecoderScratchBuffer::kPixelSize;
+
+}  // namespace libgav1
diff --git a/libgav1/src/decoder_scratch_buffer.h b/libgav1/src/decoder_scratch_buffer.h
new file mode 100644
index 0000000..9160f30
--- /dev/null
+++ b/libgav1/src/decoder_scratch_buffer.h
@@ -0,0 +1,123 @@
+#ifndef LIBGAV1_SRC_DECODER_SCRATCH_BUFFER_H_
+#define LIBGAV1_SRC_DECODER_SCRATCH_BUFFER_H_
+
+#include <cstdint>
+#include <memory>
+#include <mutex>  // NOLINT (unapproved c++11 header)
+#include <utility>
+
+#include "src/dsp/constants.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+#include "src/utils/memory.h"
+#include "src/utils/stack.h"
+
+namespace libgav1 {
+
+// Buffer to facilitate decoding a superblock.
+struct DecoderScratchBuffer : public Allocable {
+  static constexpr int kBlockDecodedStride = 34;
+
+ private:
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  static constexpr int kPixelSize = 2;
+#else
+  static constexpr int kPixelSize = 1;
+#endif
+
+ public:
+  // The following prediction modes need a prediction mask:
+  // kCompoundPredictionTypeDiffWeighted, kCompoundPredictionTypeWedge,
+  // kCompoundPredictionTypeIntra. They are mutually exclusive. This buffer is
+  // used to store the prediction mask during the inter prediction process. The
+  // mask only needs to be created for the Y plane and is used for the U & V
+  // planes.
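+  // For kCompoundPredictionTypeWedge, for example, the mask holds the
+  // per-pixel weights used to blend the two predictions.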
+  alignas(kMaxAlignment) uint8_t
+      prediction_mask[kMaxSuperBlockSizeSquareInPixels];
+
+  // For each instance of the DecoderScratchBuffer, only one of the following
+  // buffers will be used at any given time, so it is ok to share them in a
+  // union.
+  union {
+    // Union usage note: This is used only by functions in the "inter"
+    // prediction path.
+    //
+    // Buffers used for inter prediction process.
+    alignas(kMaxAlignment) uint16_t
+        prediction_buffer[2][kMaxSuperBlockSizeSquareInPixels];
+
+    struct {
+      // Union usage note: This is used only by functions in the "intra"
+      // prediction path.
+      //
+      // Buffer used for storing subsampled luma samples needed for CFL
+      // prediction. This buffer is used to avoid repetition of the subsampling
+      // for the V plane when it is already done for the U plane.
+      int16_t cfl_luma_buffer[kCflLumaBufferStride][kCflLumaBufferStride];
+
+      // Union usage note: This is used only by the
+      // Tile::ReadTransformCoefficients() function (and the helper functions
+      // that it calls). This cannot be shared with |cfl_luma_buffer| since
+      // |cfl_luma_buffer| has to live across the 3 plane loop in
+      // Tile::TransformBlock.
+      //
+      // Buffer used by Tile::ReadTransformCoefficients() to store the quantized
+      // coefficients until the dequantization process is performed.
+      int32_t quantized_buffer[kQuantizedCoefficientBufferSize];
+    };
+  };
+
+  // Buffer used for convolve. The maximum size required for this buffer is:
+  //  maximum block height (with scaling) = 2 * 128 = 256.
+  //  maximum block stride (with scaling and border aligned to 16) =
+  //     (2 * 128 + 7 + 9) * pixel_size = 272 * pixel_size.
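+  //  Here 7 is |kSubPixelTaps| - 1 (the filter border) and 9 pads the stride
+  //  to a multiple of 16. For 8bpp that is 256 * 272 * 1 = 69632 bytes.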
+  alignas(kMaxAlignment) uint8_t
+      convolve_block_buffer[256 * 272 * DecoderScratchBuffer::kPixelSize];
+
+  // Flag indicating whether the data in |cfl_luma_buffer| is valid.
+  bool cfl_luma_buffer_valid;
+
+  // Equivalent to BlockDecoded array in the spec. This stores the decoded
+  // state of every 4x4 block in a superblock. It has 1 row/column border on
+  // all 4 sides (hence the 34x34 dimension instead of 32x32). Note that the
+  // spec uses "-1" as an index to access the left and top borders. In the
+  // code, we treat the index (1, 1) as equivalent to the spec's (0, 0). So
+  // all accesses into this array will be offset by +1 when compared with the
+  // spec.
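+  // For example, the spec's BlockDecoded[ plane ][ -1 ][ -1 ] maps to
+  // block_decoded[plane][0][0] here.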
+  bool block_decoded[kMaxPlanes][kBlockDecodedStride][kBlockDecodedStride];
+};
+
+class DecoderScratchBufferPool {
+ public:
+  std::unique_ptr<DecoderScratchBuffer> Get() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (buffers_.Empty()) {
+      std::unique_ptr<DecoderScratchBuffer> scratch_buffer(
+          new (std::nothrow) DecoderScratchBuffer);
+      return scratch_buffer;
+    }
+    return buffers_.Pop();
+  }
+
+  void Release(std::unique_ptr<DecoderScratchBuffer> scratch_buffer) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    buffers_.Push(std::move(scratch_buffer));
+  }
+
+ private:
+  std::mutex mutex_;
+  // We will never need more than kMaxThreads scratch buffers since that is
+  // the maximum number of threads that can be doing work at any given time.
+  Stack<std::unique_ptr<DecoderScratchBuffer>, kMaxThreads> buffers_
+      LIBGAV1_GUARDED_BY(mutex_);
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DECODER_SCRATCH_BUFFER_H_
diff --git a/libgav1/src/dsp/arm/common_neon.h b/libgav1/src/dsp/arm/common_neon.h
index 2ba35ed..2a5046e 100644
--- a/libgav1/src/dsp/arm/common_neon.h
+++ b/libgav1/src/dsp/arm/common_neon.h
@@ -188,6 +188,10 @@
   memcpy(buf, &val, 4);
 }
 
+inline void Uint32ToMem(uint16_t* const buf, uint32_t val) {
+  memcpy(buf, &val, 4);
+}
+
 // Store 4 uint8_t values from the low half of a uint8x8_t register.
 inline void StoreLo4(uint8_t* const buf, const uint8x8_t val) {
   Uint32ToMem(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0));
@@ -198,6 +202,14 @@
   Uint32ToMem(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1));
 }
 
+// Store 2 uint16_t values from lanes |lane| * 2 and |lane| * 2 + 1 of a
+// uint16x8_t register.
+template <int lane>
+inline void Store2(uint16_t* const buf, const uint16x8_t val) {
+  Uint32ToMem(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane));
+}
+
 //------------------------------------------------------------------------------
 // Bit manipulation.
 
diff --git a/libgav1/src/dsp/arm/convolve_neon.cc b/libgav1/src/dsp/arm/convolve_neon.cc
index 4126c41..13390b9 100644
--- a/libgav1/src/dsp/arm/convolve_neon.cc
+++ b/libgav1/src/dsp/arm/convolve_neon.cc
@@ -169,6 +169,292 @@
   return vreinterpretq_u16_s16(sum);
 }
 
+template <int num_taps, int filter_index, bool negative_outside_taps = true>
+uint16x8_t SumCompoundHorizontalTaps(const uint8_t* const src,
+                                     const uint8x8_t* const v_tap) {
+  // Start with an offset to guarantee the sum is non-negative.
+  uint16x8_t v_sum = vdupq_n_u16(1 << 14);
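+  // 1 << 14 == 1 << (kBitdepth8 + kFilterBits - 1), which is large enough to
+  // cover the most negative 8bpp filter sum, so vmlsl_u8() cannot underflow.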
+  uint8x16_t v_src[8];
+  v_src[0] = vld1q_u8(&src[0]);
+  if (num_taps == 8) {
+    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
+    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
+    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
+    v_src[7] = vextq_u8(v_src[0], v_src[0], 7);
+
+    // tap signs : - + - + + - + -
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[0]), v_tap[0]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[7]), v_tap[7]);
+  } else if (num_taps == 6) {
+    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
+    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
+    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
+    if (filter_index == 0) {
+      // tap signs : + - + + - +
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+    } else {
+      if (negative_outside_taps) {
+        // tap signs : - + + + + -
+        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+      } else {
+        // tap signs : + + + + + +
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+      }
+    }
+  } else if (num_taps == 4) {
+    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
+    if (filter_index == 4) {
+      // tap signs : - + + -
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+    } else {
+      // tap signs : + + + +
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+    }
+  } else {
+    assert(num_taps == 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    // tap signs : + +
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+  }
+
+  return v_sum;
+}
+
+template <int num_taps, int filter_index>
+uint16x8_t SumHorizontalTaps2xH(const uint8_t* src, const ptrdiff_t src_stride,
+                                const uint8x8_t* const v_tap) {
+  constexpr int positive_offset_bits = kBitdepth8 + kFilterBits - 1;
+  uint16x8_t sum = vdupq_n_u16(1 << positive_offset_bits);
+  uint8x8_t input0 = vld1_u8(src);
+  src += src_stride;
+  uint8x8_t input1 = vld1_u8(src);
+  uint8x8x2_t input = vzip_u8(input0, input1);
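+  // |input.val[0]| is now {row0[0], row1[0], row0[1], row1[1], ...}, so each
+  // tap filters the same column position of both rows at once.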
+
+  if (num_taps == 2) {
+    // tap signs : + +
+    sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+    sum = vmlal_u8(sum, input.val[1], v_tap[4]);
+  } else if (filter_index == 4) {
+    // tap signs : - + + -
+    sum = vmlsl_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]);
+    sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+    sum = vmlal_u8(sum, input.val[1], v_tap[4]);
+    sum = vmlsl_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]);
+  } else {
+    // tap signs : + + + +
+    sum = vmlal_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]);
+    sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+    sum = vmlal_u8(sum, input.val[1], v_tap[4]);
+    sum = vmlal_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]);
+  }
+
+  return vrshrq_n_u16(sum, kInterRoundBitsHorizontal);
+}
+
+// TODO(johannkoenig): Rename this function. It works for more than just
+// compound convolutions.
+template <int num_taps, int step, int filter_index,
+          bool negative_outside_taps = true, bool is_2d = false,
+          bool is_8bit = false>
+void ConvolveCompoundHorizontalBlock(const uint8_t* src,
+                                     const ptrdiff_t src_stride,
+                                     void* const dest,
+                                     const ptrdiff_t pred_stride,
+                                     const int width, const int height,
+                                     const uint8x8_t* const v_tap) {
+  const uint16x8_t v_compound_round_offset = vdupq_n_u16(1 << (kBitdepth8 + 4));
+  const int16x8_t v_inter_round_bits_0 =
+      vdupq_n_s16(-kInterRoundBitsHorizontal);
+
+  auto* dest8 = static_cast<uint8_t*>(dest);
+  auto* dest16 = static_cast<uint16_t*>(dest);
+
+  if (width > 4) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        uint16x8_t v_sum =
+            SumCompoundHorizontalTaps<num_taps, filter_index,
+                                      negative_outside_taps>(&src[x], v_tap);
+        if (is_8bit) {
+          // Split the shifts as in the C code. They can be combined but that
+          // makes removing the 1 << 14 offset much more difficult.
+          v_sum = vrshrq_n_u16(v_sum, kInterRoundBitsHorizontal);
+          int16x8_t v_sum_signed = vreinterpretq_s16_u16(vsubq_u16(
+              v_sum, vdupq_n_u16(1 << (14 - kInterRoundBitsHorizontal))));
+          uint8x8_t result = vqrshrun_n_s16(
+              v_sum_signed, kFilterBits - kInterRoundBitsHorizontal);
+          vst1_u8(&dest8[x], result);
+        } else {
+          v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
+          if (!is_2d) {
+            v_sum = vaddq_u16(v_sum, v_compound_round_offset);
+          }
+          vst1q_u16(&dest16[x], v_sum);
+        }
+        x += step;
+      } while (x < width);
+      src += src_stride;
+      dest8 += pred_stride;
+      dest16 += pred_stride;
+    } while (++y < height);
+    return;
+  } else if (width == 4) {
+    int y = 0;
+    do {
+      uint16x8_t v_sum =
+          SumCompoundHorizontalTaps<num_taps, filter_index,
+                                    negative_outside_taps>(&src[0], v_tap);
+      if (is_8bit) {
+        v_sum = vrshrq_n_u16(v_sum, kInterRoundBitsHorizontal);
+        int16x8_t v_sum_signed = vreinterpretq_s16_u16(vsubq_u16(
+            v_sum, vdupq_n_u16(1 << (14 - kInterRoundBitsHorizontal))));
+        uint8x8_t result = vqrshrun_n_s16(
+            v_sum_signed, kFilterBits - kInterRoundBitsHorizontal);
+        StoreLo4(&dest8[0], result);
+      } else {
+        v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
+        if (!is_2d) {
+          v_sum = vaddq_u16(v_sum, v_compound_round_offset);
+        }
+        vst1_u16(&dest16[0], vget_low_u16(v_sum));
+      }
+      src += src_stride;
+      dest8 += pred_stride;
+      dest16 += pred_stride;
+    } while (++y < height);
+    return;
+  }
+
+  // The horizontal pass only needs to account for |num_taps| 2 and 4 when
+  // |width| == 2.
+  assert(width == 2);
+  assert(num_taps <= 4);
+
+  constexpr int positive_offset_bits = kBitdepth8 + kFilterBits - 1;
+  // Leave off + 1 << (kBitdepth8 + 3).
+  constexpr int compound_round_offset = 1 << (kBitdepth8 + 4);
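+  // The omitted 1 << (kBitdepth8 + 3) term is supplied by the input offset
+  // after the horizontal shift: (1 << positive_offset_bits) >>
+  // kInterRoundBitsHorizontal == (1 << 14) >> 3 == 1 << 11.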
+
+  if (num_taps <= 4) {
+    int y = 0;
+    do {
+      // TODO(johannkoenig): Re-order the values for storing.
+      uint16x8_t sum =
+          SumHorizontalTaps2xH<num_taps, filter_index>(src, src_stride, v_tap);
+
+      if (is_2d) {
+        dest16[0] = vgetq_lane_u16(sum, 0);
+        dest16[1] = vgetq_lane_u16(sum, 2);
+        dest16 += pred_stride;
+        dest16[0] = vgetq_lane_u16(sum, 1);
+        dest16[1] = vgetq_lane_u16(sum, 3);
+        dest16 += pred_stride;
+      } else if (!is_8bit) {
+        // None of the test vectors hit this path but the unit tests do.
+        sum = vaddq_u16(sum, vdupq_n_u16(compound_round_offset));
+
+        dest16[0] = vgetq_lane_u16(sum, 0);
+        dest16[1] = vgetq_lane_u16(sum, 2);
+        dest16 += pred_stride;
+        dest16[0] = vgetq_lane_u16(sum, 1);
+        dest16[1] = vgetq_lane_u16(sum, 3);
+        dest16 += pred_stride;
+      } else {
+        // Split the shifts as in the C code. They can be combined but that
+        // makes removing the 1 << 14 offset much more difficult.
+        int16x8_t sum_signed = vreinterpretq_s16_u16(vsubq_u16(
+            sum, vdupq_n_u16(
+                     1 << (positive_offset_bits - kInterRoundBitsHorizontal))));
+        uint8x8_t result =
+            vqrshrun_n_s16(sum_signed, kFilterBits - kInterRoundBitsHorizontal);
+
+        // Could de-interleave and vst1_lane_u16().
+        dest8[0] = vget_lane_u8(result, 0);
+        dest8[1] = vget_lane_u8(result, 2);
+        dest8 += pred_stride;
+
+        dest8[0] = vget_lane_u8(result, 1);
+        dest8[1] = vget_lane_u8(result, 3);
+        dest8 += pred_stride;
+      }
+
+      src += src_stride << 1;
+      y += 2;
+    } while (y < height - 1);
+
+    // The 2d filters have an odd |height| because the horizontal pass generates
+    // context for the vertical pass.
+    if (is_2d) {
+      assert(height % 2 == 1);
+      uint16x8_t sum = vdupq_n_u16(1 << positive_offset_bits);
+      uint8x8_t input = vld1_u8(src);
+      if (filter_index == 3) {  // |num_taps| == 2
+        sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
+        sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+      } else if (filter_index == 4) {
+        sum = vmlsl_u8(sum, RightShift<2 * 8>(input), v_tap[2]);
+        sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
+        sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+        sum = vmlsl_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
+      } else {
+        assert(filter_index == 5);
+        sum = vmlal_u8(sum, RightShift<2 * 8>(input), v_tap[2]);
+        sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
+        sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+        sum = vmlal_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
+      }
+      sum = vrshrq_n_u16(sum, kInterRoundBitsHorizontal);
+      Store2<0>(dest16, sum);
+    }
+  }
+}
+
 // Process 16 bit inputs and output 32 bits.
 template <int num_taps>
 uint32x4x2_t Sum2DVerticalTaps(const int16x8_t* const src,
@@ -241,175 +520,433 @@
   return return_val;
 }
 
+// Process 16 bit inputs and output 32 bits.
 template <int num_taps>
+uint32x4_t Sum2DVerticalTaps(const int16x4_t* const src, const int16x8_t taps) {
+  // In order to get the rollover correct with the lengthening instruction we
+  // need to treat these as signed so that they sign extend properly.
+  const int16x4_t taps_lo = vget_low_s16(taps);
+  const int16x4_t taps_hi = vget_high_s16(taps);
+  // An offset to guarantee the sum is non-negative. It covers 56 * -4590 =
+  // -257040 (the worst case negative value from the horizontal pass). It
+  // should be possible to use 1 << 18 (262144) instead of 1 << 19 but there
+  // probably isn't any benefit.
+  // |offset_bits| = bitdepth + 2 * kFilterBits - kInterRoundBitsHorizontal
+  // == 19.
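+  // (1 << 19 == 524288, which covers the 257040 magnitude above.)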
+  int32x4_t sum = vdupq_n_s32(1 << 19);
+  if (num_taps == 8) {
+    sum = vmlal_lane_s16(sum, src[0], taps_lo, 0);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
+
+    sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
+    sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
+    sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
+  } else if (num_taps == 6) {
+    sum = vmlal_lane_s16(sum, src[0], taps_lo, 1);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
+
+    sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
+    sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
+  } else if (num_taps == 4) {
+    sum = vmlal_lane_s16(sum, src[0], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
+
+    sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
+  } else if (num_taps == 2) {
+    sum = vmlal_lane_s16(sum, src[0], taps_lo, 3);
+
+    sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
+  }
+
+  // This is guaranteed to be positive. Convert it for the final shift.
+  return vreinterpretq_u32_s32(sum);
+}
+
+template <int num_taps, bool is_compound = false>
 void Filter2DVertical(const uint16_t* src, const ptrdiff_t src_stride,
-                      uint8_t* dst, const ptrdiff_t dst_stride, const int width,
-                      const int height, const int16x8_t taps) {
+                      void* const dst, const ptrdiff_t dst_stride,
+                      const int width, const int height, const int16x8_t taps,
+                      const int inter_round_bits_vertical) {
   constexpr int next_row = num_taps - 1;
+  const int32x4_t v_inter_round_bits_vertical =
+      vdupq_n_s32(-inter_round_bits_vertical);
 
-  int x = 0;
-  do {
-    int16x8_t srcs[8];
-    srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src + x));
-    if (num_taps >= 4) {
-      srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src + x + src_stride));
-      srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src + x + 2 * src_stride));
-      if (num_taps >= 6) {
-        srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src + x + 3 * src_stride));
-        srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src + x + 4 * src_stride));
-        if (num_taps == 8) {
-          srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src + x + 5 * src_stride));
-          srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src + x + 6 * src_stride));
-        }
-      }
-    }
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
 
-    int y = 0;
+  if (width > 4) {
+    int x = 0;
     do {
-      srcs[next_row] = vreinterpretq_s16_u16(
-          vld1q_u16(src + x + (y + next_row) * src_stride));
-
-      const uint32x4x2_t sums = Sum2DVerticalTaps<num_taps>(srcs, taps);
-      const uint16x8_t first_shift =
-          vcombine_u16(vqrshrn_n_u32(sums.val[0], kInterRoundBitsVertical),
-                       vqrshrn_n_u32(sums.val[1], kInterRoundBitsVertical));
-      // |single_round_offset| == (1 << bitdepth) + (1 << (bitdepth - 1)) ==
-      // 384
-      const uint8x8_t results =
-          vqmovn_u16(vqsubq_u16(first_shift, vdupq_n_u16(384)));
-
-      vst1_u8(dst + x + y * dst_stride, results);
-
-      srcs[0] = srcs[1];
+      int16x8_t srcs[8];
+      srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src + x));
       if (num_taps >= 4) {
-        srcs[1] = srcs[2];
-        srcs[2] = srcs[3];
+        srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src + x + src_stride));
+        srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src + x + 2 * src_stride));
         if (num_taps >= 6) {
-          srcs[3] = srcs[4];
-          srcs[4] = srcs[5];
+          srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src + x + 3 * src_stride));
+          srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src + x + 4 * src_stride));
           if (num_taps == 8) {
-            srcs[5] = srcs[6];
-            srcs[6] = srcs[7];
+            srcs[5] =
+                vreinterpretq_s16_u16(vld1q_u16(src + x + 5 * src_stride));
+            srcs[6] =
+                vreinterpretq_s16_u16(vld1q_u16(src + x + 6 * src_stride));
           }
         }
       }
-    } while (++y < height);
-    x += 8;
-  } while (x < width);
+
+      int y = 0;
+      do {
+        srcs[next_row] = vreinterpretq_s16_u16(
+            vld1q_u16(src + x + (y + next_row) * src_stride));
+
+        const uint32x4x2_t sums = Sum2DVerticalTaps<num_taps>(srcs, taps);
+        if (is_compound) {
+          const uint16x8_t results = vcombine_u16(
+              vmovn_u32(vqrshlq_u32(sums.val[0], v_inter_round_bits_vertical)),
+              vmovn_u32(vqrshlq_u32(sums.val[1], v_inter_round_bits_vertical)));
+          vst1q_u16(dst16 + x + y * dst_stride, results);
+        } else {
+          const uint16x8_t first_shift =
+              vcombine_u16(vqrshrn_n_u32(sums.val[0], kInterRoundBitsVertical),
+                           vqrshrn_n_u32(sums.val[1], kInterRoundBitsVertical));
+          // |single_round_offset| == (1 << bitdepth) + (1 << (bitdepth - 1)) ==
+          // 384
+          const uint8x8_t results =
+              vqmovn_u16(vqsubq_u16(first_shift, vdupq_n_u16(384)));
+
+          vst1_u8(dst8 + x + y * dst_stride, results);
+        }
+
+        srcs[0] = srcs[1];
+        if (num_taps >= 4) {
+          srcs[1] = srcs[2];
+          srcs[2] = srcs[3];
+          if (num_taps >= 6) {
+            srcs[3] = srcs[4];
+            srcs[4] = srcs[5];
+            if (num_taps == 8) {
+              srcs[5] = srcs[6];
+              srcs[6] = srcs[7];
+            }
+          }
+        }
+      } while (++y < height);
+      x += 8;
+    } while (x < width);
+    return;
+  }
+
+  assert(width == 4);
+  int16x4_t srcs[8];
+  srcs[0] = vreinterpret_s16_u16(vld1_u16(src));
+  src += src_stride;
+  if (num_taps >= 4) {
+    srcs[1] = vreinterpret_s16_u16(vld1_u16(src));
+    src += src_stride;
+    srcs[2] = vreinterpret_s16_u16(vld1_u16(src));
+    src += src_stride;
+    if (num_taps >= 6) {
+      srcs[3] = vreinterpret_s16_u16(vld1_u16(src));
+      src += src_stride;
+      srcs[4] = vreinterpret_s16_u16(vld1_u16(src));
+      src += src_stride;
+      if (num_taps == 8) {
+        srcs[5] = vreinterpret_s16_u16(vld1_u16(src));
+        src += src_stride;
+        srcs[6] = vreinterpret_s16_u16(vld1_u16(src));
+        src += src_stride;
+      }
+    }
+  }
+
+  int y = 0;
+  do {
+    srcs[next_row] = vreinterpret_s16_u16(vld1_u16(src));
+    src += src_stride;
+
+    const uint32x4_t sums = Sum2DVerticalTaps<num_taps>(srcs, taps);
+    if (is_compound) {
+      const uint16x4_t results =
+          vmovn_u32(vqrshlq_u32(sums, v_inter_round_bits_vertical));
+      vst1_u16(dst16, results);
+      dst16 += dst_stride;
+    } else {
+      const uint16x4_t first_shift =
+          vqrshrn_n_u32(sums, kInterRoundBitsVertical);
+      // |single_round_offset| == (1 << bitdepth) + (1 << (bitdepth - 1)) ==
+      // 384
+      const uint8x8_t results = vqmovn_u16(
+          vcombine_u16(vqsub_u16(first_shift, vdup_n_u16(384)), vdup_n_u16(0)));
+
+      StoreLo4(dst8, results);
+      dst8 += dst_stride;
+    }
+
+    srcs[0] = srcs[1];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[2];
+      srcs[2] = srcs[3];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[4];
+        srcs[4] = srcs[5];
+        if (num_taps == 8) {
+          srcs[5] = srcs[6];
+          srcs[6] = srcs[7];
+        }
+      }
+    }
+  } while (++y < height);
+}
+
+template <bool is_2d = false, bool is_8bit = false>
+void HorizontalPass(const uint8_t* const src, const ptrdiff_t src_stride,
+                    void* const dst, const ptrdiff_t dst_stride,
+                    const int width, const int height, const int subpixel,
+                    const int filter_index) {
+  // Duplicate the absolute value for each tap.  Negative taps are corrected
+  // by using the vmlsl_u8 instruction.  Positive taps use vmlal_u8.
+  uint8x8_t v_tap[kSubPixelTaps];
+  const int filter_id = (subpixel >> 6) & kSubPixelMask;
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    v_tap[k] = vreinterpret_u8_s8(
+        vabs_s8(vdup_n_s8(kSubPixelFilters[filter_index][filter_id][k])));
+  }
+
+  if (filter_index == 2) {  // 8 tap.
+    ConvolveCompoundHorizontalBlock<8, 8, 2, true, is_2d, is_8bit>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 1) {  // 6 tap.
+    // Check if outside taps are positive.
+    if ((filter_id == 1) | (filter_id == 15)) {
+      ConvolveCompoundHorizontalBlock<6, 8, 1, false, is_2d, is_8bit>(
+          src, src_stride, dst, dst_stride, width, height, v_tap);
+    } else {
+      ConvolveCompoundHorizontalBlock<6, 8, 1, true, is_2d, is_8bit>(
+          src, src_stride, dst, dst_stride, width, height, v_tap);
+    }
+  } else if (filter_index == 0) {  // 6 tap.
+    ConvolveCompoundHorizontalBlock<6, 8, 0, true, is_2d, is_8bit>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 4) {  // 4 tap.
+    ConvolveCompoundHorizontalBlock<4, 8, 4, true, is_2d, is_8bit>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 5) {  // 4 tap.
+    ConvolveCompoundHorizontalBlock<4, 8, 5, true, is_2d, is_8bit>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else {  // 2 tap.
+    ConvolveCompoundHorizontalBlock<2, 8, 3, true, is_2d, is_8bit>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  }
+}
+
+// There are three forms of this function:
+// 2D: input 8bit, output 16bit. |is_compound| has no effect.
+// 1D Horizontal: input 8bit, output 8bit.
+// 1D Compound Horizontal: input 8bit, output 16bit. Different rounding from 2D.
+// |width| is guaranteed to be 2 because all other cases are handled in NEON.
+template <bool is_2d = true, bool is_compound = false>
+void HorizontalPass2xH(const uint8_t* src, const ptrdiff_t src_stride,
+                       void* const dst, const ptrdiff_t dst_stride,
+                       const int height, const int filter_index, const int taps,
+                       const int subpixel) {
+  // Even though |is_compound| has no effect when |is_2d| is true, block this
+  // combination in case the compiler gets confused.
+  static_assert(!is_2d || !is_compound, "|is_compound| is ignored.");
+  // Since this only handles |width| == 2, we only need to be concerned with
+  // 2 or 4 tap filters.
+  assert(taps == 2 || taps == 4);
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  const int compound_round_offset =
+      (1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3));
+
+  const int filter_id = (subpixel >> 6) & kSubPixelMask;
+  const int taps_start = (kSubPixelTaps - taps) / 2;
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      int sum;
+      if (is_2d) {
+        // An offset to guarantee the sum is non-negative.
+        sum = 1 << (kBitdepth8 + kFilterBits - 1);
+      } else if (is_compound) {
+        sum = 0;
+      } else {
+        // 1D non-Compound. The C code uses a two-stage shift with rounding.
+        // Here the shifts are combined and the rounding bit from the first
+        // stage is added in:
+        // (((sum + 4) >> 3) + 8) >> 4 == (sum + 64 + 4) >> 7
+        sum = 4;
+      }
+      for (int k = 0; k < taps; ++k) {
+        const int tap = k + taps_start;
+        sum += kSubPixelFilters[filter_index][filter_id][tap] * src[x + k];
+      }
+      if (is_2d) {
+        dst16[x] = static_cast<int16_t>(
+            RightShiftWithRounding(sum, kInterRoundBitsHorizontal));
+      } else if (is_compound) {
+        sum = RightShiftWithRounding(sum, kInterRoundBitsHorizontal);
+        dst16[x] = sum + compound_round_offset;
+      } else {
+        // 1D non-Compound.
+        dst8[x] = static_cast<uint8_t>(
+            Clip3(RightShiftWithRounding(sum, kFilterBits), 0, 255));
+      }
+    } while (++x < 2);
+
+    src += src_stride;
+    dst8 += dst_stride;
+    dst16 += dst_stride;
+  } while (++y < height);
+}
+
+// This will always need to handle all |filter_index| values. Even with |width|
+// restricted to 2 the value of |height| can go up to at least 16.
+template <bool is_2d = true, bool is_compound = false>
+void VerticalPass2xH(const void* const src, const ptrdiff_t src_stride,
+                     void* const dst, const ptrdiff_t dst_stride,
+                     const int height, const int inter_round_bits_vertical,
+                     const int filter_index, const int taps,
+                     const int subpixel) {
+  const auto* src8 = static_cast<const uint8_t*>(src);
+  const auto* src16 = static_cast<const uint16_t*>(src);
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+  const int filter_id = (subpixel >> 6) & kSubPixelMask;
+  const int taps_start = (kSubPixelTaps - taps) / 2;
+  constexpr int max_pixel_value = (1 << kBitdepth8) - 1;
+
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      int sum;
+      if (is_2d) {
+        sum = 1 << (kBitdepth8 + 2 * kFilterBits - kInterRoundBitsHorizontal);
+      } else if (is_compound) {
+        // TODO(johannkoenig): Keeping the sum positive is valuable for neon but
+        // may not actually help the C implementation. Investigate removing
+        // this.
+        // Use this offset to cancel out 1 << (kBitdepth8 + 3) >> 3 from
+        // |compound_round_offset|.
+        sum = (1 << (kBitdepth8 + 3)) << 3;
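+        // After RightShiftWithRounding(sum, 3) below, this restores the
+        // 1 << (kBitdepth8 + 3) term left out of |compound_round_offset|.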
+      } else {
+        sum = 0;
+      }
+
+      for (int k = 0; k < taps; ++k) {
+        const int tap = k + taps_start;
+        if (is_2d) {
+          sum += kSubPixelFilters[filter_index][filter_id][tap] *
+                 src16[x + k * src_stride];
+        } else {
+          sum += kSubPixelFilters[filter_index][filter_id][tap] *
+                 src8[x + k * src_stride];
+        }
+      }
+
+      if (is_2d) {
+        if (is_compound) {
+          dst16[x] = static_cast<uint16_t>(
+              RightShiftWithRounding(sum, inter_round_bits_vertical));
+        } else {
+          constexpr int single_round_offset =
+              (1 << kBitdepth8) + (1 << (kBitdepth8 - 1));
+          dst8[x] = static_cast<uint8_t>(
+              Clip3(RightShiftWithRounding(sum, kInterRoundBitsVertical) -
+                        single_round_offset,
+                    0, max_pixel_value));
+        }
+      } else if (is_compound) {
+        // Leave off + 1 << (kBitdepth8 + 3).
+        constexpr int compound_round_offset = 1 << (kBitdepth8 + 4);
+        dst16[x] = RightShiftWithRounding(sum, 3) + compound_round_offset;
+      } else {
+        // 1D non-compound.
+        dst8[x] = static_cast<uint8_t>(Clip3(
+            RightShiftWithRounding(sum, kFilterBits), 0, max_pixel_value));
+      }
+    } while (++x < 2);
+
+    src8 += src_stride;
+    src16 += src_stride;
+    dst8 += dst_stride;
+    dst16 += dst_stride;
+  } while (++y < height);
+}
+
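+// Returns the number of taps the filters of |filter_index| actually use:
+// {0, 1} -> 6, {2} -> 8, {3} -> 2, {4, 5} -> 4.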
+int NumTapsInFilter(const int filter_index) {
+  if (filter_index < 2) {
+    // Despite the names these only use 6 taps.
+    // kInterpolationFilterEightTap
+    // kInterpolationFilterEightTapSmooth
+    return 6;
+  }
+
+  if (filter_index == 2) {
+    // kInterpolationFilterEightTapSharp
+    return 8;
+  }
+
+  if (filter_index == 3) {
+    // kInterpolationFilterBilinear
+    return 2;
+  }
+
+  assert(filter_index > 3);
+  // For small sizes (width/height <= 4) the large filters are replaced with 4
+  // tap options.
+  // If the original filters were |kInterpolationFilterEightTap| or
+  // |kInterpolationFilterEightTapSharp| then it becomes
+  // |kInterpolationFilterSwitchable|.
+  // If it was |kInterpolationFilterEightTapSmooth| then it becomes an unnamed 4
+  // tap filter.
+  return 4;
 }
 
 void Convolve2D_NEON(const void* const reference,
                      const ptrdiff_t reference_stride,
                      const int horizontal_filter_index,
                      const int vertical_filter_index,
-                     const uint8_t /*inter_round_bits_vertical*/,
+                     const int /*inter_round_bits_vertical*/,
                      const int subpixel_x, const int subpixel_y,
                      const int /*step_x*/, const int /*step_y*/,
                      const int width, const int height, void* prediction,
                      const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
-  int horizontal_taps, horizontal_taps_start, vertical_taps,
-      vertical_taps_start;
+  const int horizontal_taps = NumTapsInFilter(horiz_filter_index);
+  const int vertical_taps = NumTapsInFilter(vert_filter_index);
 
-  if (horiz_filter_index < 2) {
-    horizontal_taps = 6;
-    horizontal_taps_start = 1;
-  } else if (horiz_filter_index == 2) {
-    horizontal_taps = 8;
-    horizontal_taps_start = 0;
-  } else if (horiz_filter_index == 3) {
-    horizontal_taps = 2;
-    horizontal_taps_start = 3;
-  } else /* if (horiz_filter_index > 3) */ {
-    horizontal_taps = 4;
-    horizontal_taps_start = 2;
-  }
-
-  if (vert_filter_index < 2) {
-    vertical_taps = 6;
-    vertical_taps_start = 1;
-  } else if (vert_filter_index == 2) {
-    vertical_taps = 8;
-    vertical_taps_start = 0;
-  } else if (vert_filter_index == 3) {
-    vertical_taps = 2;
-    vertical_taps_start = 3;
-  } else /* if (vert_filter_index > 3) */ {
-    vertical_taps = 4;
-    vertical_taps_start = 2;
-  }
-
-  // Neon processes blocks of 8x8 for context during the horizontal pass so it
-  // still does a few more than it needs.
-  const int intermediate_height = height + vertical_taps - 1;
   // The output of the horizontal filter is guaranteed to fit in 16 bits.
-  uint16_t intermediate_result[kMaxSuperBlockSizeInPixels *
-                               (kMaxSuperBlockSizeInPixels + kSubPixelTaps)];
+  uint16_t
+      intermediate_result[kMaxSuperBlockSizeInPixels *
+                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
   const int intermediate_stride = width;
-  const int max_pixel_value = 255;
+  const int intermediate_height = height + vertical_taps - 1;
 
-  if (width > 4) {
-    // Horizontal filter.
-    const int horiz_filter_id = (subpixel_x >> 6) & kSubPixelMask;
-    const int16x8_t horiz_taps =
-        vld1q_s16(kSubPixelFilters[horiz_filter_index][horiz_filter_id]);
-
-    uint16_t* intermediate = intermediate_result;
+  if (width >= 4) {
     const ptrdiff_t src_stride = reference_stride;
-    // Offset for 8 tap horizontal filter and |vertical_taps|.
     const auto* src = static_cast<const uint8_t*>(reference) -
-                      ((vertical_taps / 2) - 1) * src_stride -
-                      kHorizontalOffset;
-    int y = 0;
-    do {
-      int x = 0;
-      do {
-        uint8x16_t temp[8];
-        uint8x8_t input[16];
-        for (int i = 0; i < 8; ++i) {
-          temp[i] = vld1q_u8(src + 0 + x + i * src_stride);
-        }
-        // TODO(johannkoenig): It should be possible to get the transpose
-        // started with vld2().
-        Transpose16x8(temp, input);
-        int16x8_t input16[16];
-        for (int i = 0; i < 16; ++i) {
-          input16[i] = ZeroExtend(input[i]);
-        }
+                      (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
 
-        // TODO(johannkoenig): Explore moving the branch outside the main loop.
-        uint16x8_t output[8];
-        if (horizontal_taps == 8) {
-          for (int i = 0; i < 8; ++i) {
-            const uint16x8_t neon_sums =
-                SumTaps8To16<8>(input16 + i, horiz_taps);
-            output[i] = vrshrq_n_u16(neon_sums, kInterRoundBitsHorizontal);
-          }
-        } else if (horizontal_taps == 6) {
-          for (int i = 0; i < 8; ++i) {
-            const uint16x8_t neon_sums =
-                SumTaps8To16<6>(input16 + i + 1, horiz_taps);
-            output[i] = vrshrq_n_u16(neon_sums, kInterRoundBitsHorizontal);
-          }
-        } else {  // |horizontal_taps| == 2
-          for (int i = 0; i < 8; ++i) {
-            const uint16x8_t neon_sums =
-                SumTaps8To16<2>(input16 + i + 3, horiz_taps);
-            output[i] = vrshrq_n_u16(neon_sums, kInterRoundBitsHorizontal);
-          }
-        }
-
-        Transpose8x8(output);
-        for (int i = 0; i < 8; ++i) {
-          vst1q_u16(intermediate + x + i * intermediate_stride, output[i]);
-        }
-        x += 8;
-      } while (x < width);
-      src += src_stride << 3;
-      intermediate += intermediate_stride << 3;
-      y += 8;
-    } while (y < intermediate_height);
+    HorizontalPass<true>(src, src_stride, intermediate_result,
+                         intermediate_stride, width, intermediate_height,
+                         subpixel_x, horiz_filter_index);
 
     // Vertical filter.
     auto* dest = static_cast<uint8_t*>(prediction);
@@ -420,63 +952,34 @@
 
     if (vertical_taps == 8) {
       Filter2DVertical<8>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps);
+                          dest_stride, width, height, taps, 0);
     } else if (vertical_taps == 6) {
       Filter2DVertical<6>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps);
+                          dest_stride, width, height, taps, 0);
     } else if (vertical_taps == 4) {
       Filter2DVertical<4>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps);
+                          dest_stride, width, height, taps, 0);
     } else {  // |vertical_taps| == 2
       Filter2DVertical<2>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps);
+                          dest_stride, width, height, taps, 0);
     }
   } else {
+    assert(width == 2);
     // Horizontal filter.
-    // Filter types used for width <= 4 are different from those for width > 4.
-    // When width > 4, the valid filter index range is always [0, 3].
-    // When width <= 4, the valid filter index range is always [4, 5].
-    // Similarly for height.
-    uint16_t* intermediate = intermediate_result;
-    const ptrdiff_t src_stride = reference_stride;
-    const auto* src = static_cast<const uint8_t*>(reference) -
-                      ((vertical_taps / 2) - 1) * src_stride -
-                      ((horizontal_taps / 2) - 1);
+    const auto* const src = static_cast<const uint8_t*>(reference) -
+                            ((vertical_taps / 2) - 1) * reference_stride -
+                            ((horizontal_taps / 2) - 1);
+
+    HorizontalPass2xH(src, reference_stride, intermediate_result,
+                      intermediate_stride, intermediate_height,
+                      horiz_filter_index, horizontal_taps, subpixel_x);
+
+    // Vertical filter.
     auto* dest = static_cast<uint8_t*>(prediction);
     const ptrdiff_t dest_stride = pred_stride;
-    int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-    for (int y = 0; y < intermediate_height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        // An offset to guarantee the sum is non negative.
-        int sum = 1 << 14;
-        for (int k = 0; k < horizontal_taps; ++k) {
-          const int tap = k + horizontal_taps_start;
-          sum +=
-              kSubPixelFilters[horiz_filter_index][filter_id][tap] * src[x + k];
-        }
-        intermediate[x] = static_cast<int16_t>(RightShiftWithRounding(sum, 3));
-      }
-      src += src_stride;
-      intermediate += intermediate_stride;
-    }
-    // Vertical filter.
-    intermediate = intermediate_result;
-    filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        // An offset to guarantee the sum is non negative.
-        int sum = 1 << 19;
-        for (int k = 0; k < vertical_taps; ++k) {
-          const int tap = k + vertical_taps_start;
-          sum += kSubPixelFilters[vert_filter_index][filter_id][tap] *
-                 intermediate[k * intermediate_stride + x];
-        }
-        dest[x] = static_cast<uint8_t>(
-            Clip3(RightShiftWithRounding(sum, 11) - 384, 0, max_pixel_value));
-      }
-      dest += dest_stride;
-      intermediate += intermediate_stride;
-    }
+
+    VerticalPass2xH(intermediate_result, intermediate_stride, dest, dest_stride,
+                    height, 0, vert_filter_index, vertical_taps, subpixel_y);
   }
 }
 
@@ -568,6 +1071,10 @@
     int16x8_t s[(grade_x + 1) * 8];
     const uint8_t* src_x =
         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    // TODO(petersonab,b/139707209): Fix source buffer overreads.
+    // For example, when |height| == 2 and |num_taps| == 8 then
+    // |intermediate_height| == 9. On the second pass this will load and
+    // transpose 7 rows past where |src| may end.
     Load8x8(src_x, src_stride, s);
     Transpose8x8(s);
     if (grade_x > 1) {
@@ -843,7 +1350,7 @@
 void ConvolveCompoundScale2D_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int vertical_filter_index,
-    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int inter_round_bits_vertical, const int subpixel_x,
     const int subpixel_y, const int step_x, const int step_y, const int width,
     const int height, void* prediction, const ptrdiff_t pred_stride) {
   const int intermediate_height =
@@ -944,7 +1451,7 @@
                              const ptrdiff_t reference_stride,
                              const int horizontal_filter_index,
                              const int /*vertical_filter_index*/,
-                             const uint8_t /*inter_round_bits_vertical*/,
+                             const int /*inter_round_bits_vertical*/,
                              const int subpixel_x, const int /*subpixel_y*/,
                              const int /*step_x*/, const int /*step_y*/,
                              const int width, const int height,
@@ -953,101 +1460,10 @@
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   // Set |src| to the outermost tap.
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
-  const ptrdiff_t src_stride = reference_stride;
   auto* dest = static_cast<uint8_t*>(prediction);
-  const ptrdiff_t dest_stride = pred_stride;
-  const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-  const int block_output_height = std::min(height, 8);
-  const int16x8_t four = vdupq_n_s16(4);
 
-  int16x8_t taps;
-  if (filter_index < 3) {
-    // 6 and 8 tap filters.
-    taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
-  } else {
-    // The 2 tap filter only uses the lower half of |taps|.
-    taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id] + 2);
-  }
-
-  // TODO(johannkoenig): specialize small |height| variants so we don't
-  // overread |reference|.
-  if (width > 4 && height > 4) {
-    int y = 0;
-    do {
-      // This was intended to load and transpose 16 values before the |width|
-      // loop. At the end of the loop it would keep 8 of those values and only
-      // load and transpose 8 additional values. Unfortunately the approach did
-      // not appear to provide any benefit.
-      int x = 0;
-      do {
-        uint8x16_t temp[8];
-        uint8x8_t input[16];
-        for (int i = 0; i < 8; ++i) {
-          temp[i] = vld1q_u8(src + x + i * src_stride);
-        }
-        // TODO(johannkoenig): It should be possible to get the transpose
-        // started with vld4().
-        Transpose16x8(temp, input);
-        int16x8_t input16[16];
-        for (int i = 0; i < 16; ++i) {
-          input16[i] = ZeroExtend(input[i]);
-        }
-
-        // This does not handle |filter_index| > 3 because those 4 tap filters
-        // are only used when |width| <= 4.
-        // TODO(johannkoenig): Explore moving the branch outside the main loop.
-        uint8x8_t output[8];
-        if (filter_index == 2) {  // 8 taps.
-          for (int i = 0; i < 8; ++i) {
-            const int16x8_t neon_sums = SumTaps<8>(input16 + i, taps);
-            output[i] =
-                vqrshrun_n_s16(vqaddq_s16(neon_sums, four), kFilterBits);
-          }
-        } else if (filter_index < 2) {  // 6 taps.
-          for (int i = 0; i < 8; ++i) {
-            const int16x8_t neon_sums = SumTaps<6>(input16 + i + 1, taps);
-            output[i] =
-                vqrshrun_n_s16(vqaddq_s16(neon_sums, four), kFilterBits);
-          }
-        } else {  // |filter_index| == 3. 2 taps.
-          for (int i = 0; i < 8; ++i) {
-            const int16x8_t neon_sums = SumTaps<2>(input16 + i + 3, taps);
-            output[i] =
-                vqrshrun_n_s16(vqaddq_s16(neon_sums, four), kFilterBits);
-          }
-        }
-
-        Transpose8x8(output);
-
-        int i = 0;
-        do {
-          vst1_u8(dest + x + i * dest_stride, output[i]);
-        } while (++i < block_output_height);
-        x += 8;
-      } while (x < width);
-      y += 8;
-      src += 8 * src_stride;
-      dest += 8 * dest_stride;
-    } while (y < height);
-  } else {
-    // TODO(johannkoenig): Investigate 2xH and 4xH. During the original
-    // implementation 4x2 was slower than C, 4x4 reached parity, and 4x8
-    // was < 20% faster.
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        int sum = 0;
-        for (int k = 0; k < kSubPixelTaps; ++k) {
-          sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
-        }
-        // We can combine the shifts if we compensate for the skipped rounding.
-        // ((sum + 4 >> 3) + 8) >> 4 == (sum + 64 + 4) >> 7;
-        dest[x] = static_cast<uint8_t>(
-            Clip3(RightShiftWithRounding(sum + 4, kFilterBits), 0, 255));
-      }
-      src += src_stride;
-      dest += dest_stride;
-    }
-  }
+  HorizontalPass<false, true>(src, reference_stride, dest, pred_stride, width,
+                              height, subpixel_x, filter_index);
 }
 
 template <int min_width, int num_taps>
@@ -1123,7 +1539,7 @@
                            const ptrdiff_t reference_stride,
                            const int /*horizontal_filter_index*/,
                            const int vertical_filter_index,
-                           const uint8_t /*inter_round_bits_vertical*/,
+                           const int /*inter_round_bits_vertical*/,
                            const int /*subpixel_x*/, const int subpixel_y,
                            const int /*step_x*/, const int /*step_y*/,
                            const int width, const int height, void* prediction,
@@ -1201,31 +1617,19 @@
       }
     }
   } else {
-    // TODO(johannkoenig): Determine if it is worth writing a 2xH
-    // implementation.
     assert(width == 2);
-    const int max_pixel_value = 255;
-    int y = 0;
-    do {
-      for (int x = 0; x < 2; ++x) {
-        int sum = 0;
-        for (int k = 0; k < kSubPixelTaps; ++k) {
-          sum += kSubPixelFilters[filter_index][filter_id][k] *
-                 src[k * src_stride + x];
-        }
-        dest[x] = static_cast<uint8_t>(Clip3(
-            RightShiftWithRounding(sum, kFilterBits), 0, max_pixel_value));
-      }
-      src += src_stride;
-      dest += pred_stride;
-    } while (++y < height);
+    const int taps = NumTapsInFilter(filter_index);
+    src =
+        static_cast<const uint8_t*>(reference) - ((taps / 2) - 1) * src_stride;
+    VerticalPass2xH</*is_2d=*/false>(src, src_stride, dest, pred_stride, height,
+                                     0, filter_index, taps, subpixel_y);
   }
 }
 
 void ConvolveCompoundCopy_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
+    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -1339,9 +1743,10 @@
       // the end. Instead we use that to compensate for the initial offset.
       // (1 << (bitdepth + 4)) + (1 << (bitdepth + 3)) == (1 << 12) + (1 << 11)
       // After taking into account the shift above:
-      // RightShiftWithRounding(LeftShift(sum, bits_shift), inter_round_bits[1])
-      // where bits_shift == kFilterBits - inter_round_bits[0] == 4
-      // and inter_round_bits[1] == 7
+      // RightShiftWithRounding(LeftShift(sum, bits_shift),
+      //                        inter_round_bits_vertical)
+      // where bits_shift == kFilterBits - kInterRoundBitsHorizontal == 4
+      // and inter_round_bits_vertical == 7
       // and simplifying it to RightShiftWithRounding(sum, 3)
       // we see that the initial offset of 1 << 14 >> 3 == 1 << 11 and
       // |compound_round_offset| can be simplified to 1 << 12.
@@ -1374,7 +1779,7 @@
 void ConvolveCompoundVertical_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int vertical_filter_index,
-    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
+    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
     const int subpixel_y, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -1384,7 +1789,6 @@
       static_cast<const uint8_t*>(reference) - kVerticalOffset * src_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
-  const int compound_round_offset = 1 << 12;  // Leave off + 1 << 11.
 
   if (width >= 4) {
     const int16x8_t taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
@@ -1424,236 +1828,108 @@
     }
   } else {
     assert(width == 2);
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < 2; ++x) {
-        // Use an offset to avoid 32 bits.
-        int sum = 1 << 14;
-        for (int k = 0; k < kSubPixelTaps; ++k) {
-          sum += kSubPixelFilters[filter_index][filter_id][k] *
-                 src[k * src_stride + x];
-        }
-        // |compound_round_offset| has been modified to take into account the
-        // offset used above. The 1 << 11 term cancels out with 1 << 14 >> 3.
-        dest[x] = RightShiftWithRounding(sum, 3) + compound_round_offset;
-      }
-      src += src_stride;
-      dest += pred_stride;
-    }
-  }
-}
-
-template <int num_taps, int filter_index, bool negative_outside_taps = true>
-uint16x8_t SumCompoundHorizontalTaps(const uint8_t* const src,
-                                     uint8x8_t* v_tap) {
-  // Start with an offset to guarantee the sum is non negative.
-  uint16x8_t v_sum = vdupq_n_u16(1 << 14);
-  uint8x16_t v_src[8];
-  v_src[0] = vld1q_u8(&src[0]);
-  if (num_taps == 8) {
-    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
-    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
-    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
-    v_src[7] = vextq_u8(v_src[0], v_src[0], 7);
-
-    // tap signs : - + - + + - + -
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[0]), v_tap[0]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[7]), v_tap[7]);
-  } else if (num_taps == 6) {
-    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
-    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
-    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
-    if (filter_index == 0) {
-      // tap signs : + - + + - +
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-    } else {
-      if (negative_outside_taps) {
-        // tap signs : - + + + + -
-        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-      } else {
-        // tap signs : + + + + + +
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-      }
-    }
-  } else if (num_taps == 4) {
-    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
-    if (filter_index == 4) {
-      // tap signs : - + + -
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-    } else {
-      // tap signs : + + + +
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-    }
-  } else {
-    assert(num_taps == 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    // tap signs : + +
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-  }
-
-  return v_sum;
-}
-
-template <int num_taps, int step, int filter_index,
-          bool negative_outside_taps = true>
-void ConvolveCompoundHorizontalBlock(const uint8_t* src, ptrdiff_t src_stride,
-                                     uint16_t* dest, ptrdiff_t pred_stride,
-                                     const int width, const int height,
-                                     uint8x8_t* v_tap,
-                                     int16x8_t v_inter_round_bits_0,
-                                     int16x8_t v_bits_shift,
-                                     uint16x8_t v_compound_round_offset) {
-  if (width > 4) {
-    int y = 0;
-    do {
-      int x = 0;
-      do {
-        uint16x8_t v_sum =
-            SumCompoundHorizontalTaps<num_taps, filter_index,
-                                      negative_outside_taps>(&src[x], v_tap);
-        v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
-        v_sum = vshlq_u16(v_sum, v_bits_shift);
-        v_sum = vaddq_u16(v_sum, v_compound_round_offset);
-        vst1q_u16(&dest[x], v_sum);
-        x += step;
-      } while (x < width);
-      src += src_stride;
-      dest += pred_stride;
-    } while (++y < height);
-  } else {
-    int y = 0;
-    do {
-      uint16x8_t v_sum =
-          SumCompoundHorizontalTaps<num_taps, filter_index,
-                                    negative_outside_taps>(&src[0], v_tap);
-      v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
-      v_sum = vshlq_u16(v_sum, v_bits_shift);
-      v_sum = vaddq_u16(v_sum, v_compound_round_offset);
-      vst1_u16(&dest[0], vget_low_u16(v_sum));
-      src += src_stride;
-      dest += pred_stride;
-    } while (++y < height);
+    const int taps = NumTapsInFilter(filter_index);
+    src =
+        static_cast<const uint8_t*>(reference) - ((taps / 2) - 1) * src_stride;
+    VerticalPass2xH</*is_2d=*/false, /*is_compound=*/true>(
+        src, src_stride, dest, pred_stride, height, 0, filter_index, taps,
+        subpixel_y);
   }
 }
 
 void ConvolveCompoundHorizontal_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int /*inter_round_bits_vertical*/, const int subpixel_x,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
-  const ptrdiff_t src_stride = reference_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
-  const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-  const int bits_shift = kFilterBits - inter_round_bits_vertical;
 
-  const int compound_round_offset =
-      (1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3));
+  HorizontalPass(src, reference_stride, dest, pred_stride, width, height,
+                 subpixel_x, filter_index);
+}
+
+void ConvolveCompound2D_NEON(const void* const reference,
+                             const ptrdiff_t reference_stride,
+                             const int horizontal_filter_index,
+                             const int vertical_filter_index,
+                             const int inter_round_bits_vertical,
+                             const int subpixel_x, const int subpixel_y,
+                             const int /*step_x*/, const int /*step_y*/,
+                             const int width, const int height,
+                             void* prediction, const ptrdiff_t pred_stride) {
+  // The output of the horizontal filter, i.e. |intermediate_result|, is
+  // guaranteed to fit in int16_t.
+  uint16_t
+      intermediate_result[kMaxSuperBlockSizeInPixels *
+                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+  const int intermediate_stride = kMaxSuperBlockSizeInPixels;
+
+  // Horizontal filter.
+  // Filter types used for width <= 4 are different from those for width > 4.
+  // When width > 4, the valid filter index range is always [0, 3].
+  // When width <= 4, the valid filter index range is always [4, 5].
+  // Similarly for height.
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int horizontal_taps = NumTapsInFilter(horiz_filter_index);
+  const int vertical_taps = NumTapsInFilter(vert_filter_index);
+  uint16_t* intermediate = intermediate_result;
+  const int intermediate_height = height + vertical_taps - 1;
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    kVerticalOffset * src_stride - kHorizontalOffset;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  int filter_id = (subpixel_x >> 6) & kSubPixelMask;
 
   if (width >= 4) {
-    // Duplicate the absolute value for each tap.  Negative taps are corrected
-    // by using the vmlsl_u8 instruction.  Positive taps use vmlal_u8.
-    uint8x8_t v_tap[kSubPixelTaps];
-    for (int k = 0; k < kSubPixelTaps; ++k) {
-      v_tap[k] = vreinterpret_u8_s8(
-          vabs_s8(vdup_n_s8(kSubPixelFilters[filter_index][filter_id][k])));
-    }
+    // TODO(johannkoenig): Use |width| for |intermediate_stride|.
+    src = static_cast<const uint8_t*>(reference) -
+          (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
+    HorizontalPass<true>(src, src_stride, intermediate_result,
+                         intermediate_stride, width, intermediate_height,
+                         subpixel_x, horiz_filter_index);
 
-    const int16x8_t v_inter_round_bits_0 =
-        vdupq_n_s16(-kInterRoundBitsHorizontal);
-    const int16x8_t v_bits_shift = vdupq_n_s16(bits_shift);
+    // Vertical filter.
+    intermediate = intermediate_result;
+    filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
 
-    const uint16x8_t v_compound_round_offset =
-        vdupq_n_u16(compound_round_offset - (1 << (kBitdepth8 + 3)));
+    const ptrdiff_t dest_stride = pred_stride;
+    const int16x8_t taps =
+        vld1q_s16(kSubPixelFilters[vert_filter_index][filter_id]);
 
-    if (filter_index == 2) {  // 8 tap.
-      ConvolveCompoundHorizontalBlock<8, 8, 2>(
-          src, src_stride, dest, pred_stride, width, height, v_tap,
-          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
-    } else if (filter_index == 1) {  // 6 tap.
-      // Check if outside taps are positive.
-      if ((filter_id == 1) | (filter_id == 15)) {
-        ConvolveCompoundHorizontalBlock<6, 8, 1, false>(
-            src, src_stride, dest, pred_stride, width, height, v_tap,
-            v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
-      } else {
-        ConvolveCompoundHorizontalBlock<6, 8, 1>(
-            src, src_stride, dest, pred_stride, width, height, v_tap,
-            v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
-      }
-    } else if (filter_index == 0) {  // 6 tap.
-      ConvolveCompoundHorizontalBlock<6, 8, 0>(
-          src, src_stride, dest, pred_stride, width, height, v_tap,
-          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
-    } else if (filter_index == 4) {  // 4 tap.
-      ConvolveCompoundHorizontalBlock<4, 8, 4>(
-          src, src_stride, dest, pred_stride, width, height, v_tap,
-          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
-    } else if (filter_index == 5) {  // 4 tap.
-      ConvolveCompoundHorizontalBlock<4, 8, 5>(
-          src, src_stride, dest, pred_stride, width, height, v_tap,
-          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
-    } else {  // 2 tap.
-      ConvolveCompoundHorizontalBlock<2, 8, 3>(
-          src, src_stride, dest, pred_stride, width, height, v_tap,
-          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+    if (vertical_taps == 8) {
+      Filter2DVertical<8, /*is_compound=*/true>(
+          intermediate, intermediate_stride, dest, dest_stride, width, height,
+          taps, inter_round_bits_vertical);
+    } else if (vertical_taps == 6) {
+      Filter2DVertical<6, /*is_compound=*/true>(
+          intermediate, intermediate_stride, dest, dest_stride, width, height,
+          taps, inter_round_bits_vertical);
+    } else if (vertical_taps == 4) {
+      Filter2DVertical<4, /*is_compound=*/true>(
+          intermediate, intermediate_stride, dest, dest_stride, width, height,
+          taps, inter_round_bits_vertical);
+    } else {  // |vertical_taps| == 2
+      Filter2DVertical<2, /*is_compound=*/true>(
+          intermediate, intermediate_stride, dest, dest_stride, width, height,
+          taps, inter_round_bits_vertical);
     }
   } else {
-    // 2xH
-    int y = 0;
-    do {
-      for (int x = 0; x < 2; ++x) {
-        int sum = 0;
-        for (int k = 0; k < kSubPixelTaps; ++k) {
-          sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
-        }
-        sum = RightShiftWithRounding(sum, kInterRoundBitsHorizontal)
-              << bits_shift;
-        dest[x] = sum + compound_round_offset;
-      }
-      src += src_stride;
-      dest += pred_stride;
-    } while (++y < height);
+    src = static_cast<const uint8_t*>(reference) -
+          ((vertical_taps / 2) - 1) * src_stride - ((horizontal_taps / 2) - 1);
+
+    HorizontalPass2xH(src, src_stride, intermediate_result, intermediate_stride,
+                      intermediate_height, horiz_filter_index, horizontal_taps,
+                      subpixel_x);
+
+    VerticalPass2xH</*is_2d=*/true, /*is_compound=*/true>(
+        intermediate_result, intermediate_stride, dest, pred_stride, height,
+        inter_round_bits_vertical, vert_filter_index, vertical_taps,
+        subpixel_y);
   }
 }
 
@@ -1662,14 +1938,14 @@
   assert(dsp != nullptr);
   dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
   dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
-  // TODO(b/139707209): reenable after segfault on android is fixed.
-  // dsp->convolve[0][0][1][1] = Convolve2D_NEON;
-  static_cast<void>(Convolve2D_NEON);
+  dsp->convolve[0][0][1][1] = Convolve2D_NEON;
 
   dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
   dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
   dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
+  dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON;
 
+  // TODO(petersonab,b/139707209): Fix source buffer overreads.
   // dsp->convolve_scale[1] = ConvolveCompoundScale2D_NEON;
   static_cast<void>(ConvolveCompoundScale2D_NEON);
 }
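
For readers tracing the rounding comments in the compound paths above, here is a
standalone scalar sketch (illustrative, not part of the patch;
RightShiftWithRounding below is a local stand-in for the libgav1 helper) of the
simplification those comments describe:

#include <cassert>

namespace {
// Local stand-in for libgav1's helper, valid for non-negative inputs.
int RightShiftWithRounding(int value, int bits) {
  return (value + (1 << (bits - 1))) >> bits;
}
}  // namespace

int main() {
  // bits_shift == kFilterBits - kInterRoundBitsHorizontal == 7 - 3 == 4 and
  // inter_round_bits_vertical == 7 in the 8bpp compound path.
  for (int sum = 0; sum < (1 << 15); ++sum) {
    assert(RightShiftWithRounding(sum << 4, 7) ==
           RightShiftWithRounding(sum, 3));
  }
  // After the 3-bit shift the initial 1 << 14 offset contributes 1 << 11,
  // cancelling the 1 << 11 term of (1 << 12) + (1 << 11); only 1 << 12
  // remains to be added as the compound round offset.
  return 0;
}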
diff --git a/libgav1/src/dsp/arm/convolve_neon.h b/libgav1/src/dsp/arm/convolve_neon.h
index 6b5873c..7eae136 100644
--- a/libgav1/src/dsp/arm/convolve_neon.h
+++ b/libgav1/src/dsp/arm/convolve_neon.h
@@ -16,13 +16,14 @@
 #if LIBGAV1_ENABLE_NEON
 #define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_DSP_NEON
-// TODO(b/139707209): reenable after segfault on android is fixed.
-// #define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_DSP_NEON
 
+// TODO(petersonab,b/139707209): Fix source buffer overreads.
 // #define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_DSP_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/intrapred_directional_neon.cc b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
index 751ba89..0684fa5 100644
--- a/libgav1/src/dsp/arm/intrapred_directional_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
@@ -234,21 +234,14 @@
     int y = 0;
     do {
       const int top_base_x = top_x >> 6;
-
-      if (top_base_x >= max_base_x) {
-        for (int i = y; i < height; ++i) {
-          memset(dst, top[max_base_x], width);
-          dst += stride;
-        }
-        return;
-      }
-
       const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
-
       uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), all);
-
       int x = 0;
-      do {
+      // Only calculate a block of 8 when at least one of the output values is
+      // within range. Otherwise it can read off the end of |top|.
+      const int must_calculate_width =
+          std::min(width, max_base_x - top_base_x + 7) & ~7;
+      for (; x < must_calculate_width; x += 8) {
         const uint8x8_t max_base_mask = vclt_u8(base_v, max_base);
 
         // Since these |xstep| values can not be upsampled the load is
@@ -260,10 +253,9 @@
             vbsl_u8(max_base_mask, value, top_max_base);
 
         vst1_u8(dst + x, masked_value);
-
         base_v = vadd_u8(base_v, block_step);
-        x += 8;
-      } while (x < width);
+      }
+      memset(dst + x, top[max_base_x], width - x);
       dst += stride;
       top_x += xstep;
     } while (++y < height);
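
A scalar model of the new bounds handling (hypothetical helper; the per-lane
sub-pixel interpolation the NEON code performs is omitted): blocks are computed
only while at least one output can still index inside |top|, and the remainder
collapses to a memset of the last valid sample.

#include <algorithm>
#include <cstdint>
#include <cstring>

void PredictRowSketch(uint8_t* dst, const uint8_t* top, const int width,
                      const int top_base_x, const int max_base_x) {
  // Round down to a multiple of 8, mirroring the 8-lane NEON blocks.
  const int must_calculate_width =
      std::min(width, max_base_x - top_base_x + 7) & ~7;
  int x = 0;
  for (; x < must_calculate_width; ++x) {
    const int base = top_base_x + x;
    dst[x] = (base < max_base_x) ? top[base] : top[max_base_x];
  }
  // Every remaining output is out of range and equals top[max_base_x].
  memset(dst + x, top[max_base_x], width - x);
}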
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.cc b/libgav1/src/dsp/arm/inverse_transform_neon.cc
index 9e57365..b251639 100644
--- a/libgav1/src/dsp/arm/inverse_transform_neon.cc
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.cc
@@ -425,6 +425,38 @@
 //------------------------------------------------------------------------------
 // Discrete Cosine Transforms (DCT).
 
+template <int width>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, const void* source,
+                                     int non_zero_coeff_count,
+                                     bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const int16x8_t v_src = vdupq_n_s16(src[0]);
+  const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
+  const int16x8_t v_src_round =
+      vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+  const int16x8_t s0 = vbslq_s16(v_mask, v_src_round, v_src);
+  const int16_t cos128 = Cos128(32);
+  const int16x8_t xy = vqrdmulhq_s16(s0, vdupq_n_s16(cos128 << 3));
+  // vqrshlq_s16 will shift right if the shift value is negative.
+  const int16x8_t xy_shifted = vqrshlq_s16(xy, vdupq_n_s16(-row_shift));
+
+  if (width == 4) {
+    vst1_s16(dst, vget_low_s16(xy_shifted));
+  } else {
+    for (int i = 0; i < width; i += 8) {
+      vst1q_s16(dst, xy_shifted);
+      dst += 8;
+    }
+  }
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation,
           bool is_fast_bufferfly = false>
 LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) {
@@ -1982,8 +2014,15 @@
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
+    const int row_shift = (tx_height == 16);
+
+    if (DctDcOnly<4>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                     row_shift)) {
+      return;
+    }
+
+    const int num_rows = tx_height;
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -2038,8 +2077,16 @@
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<8>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                     row_shift)) {
+      return;
+    }
+
+    const int num_rows = tx_height;
+    if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
 
@@ -2056,7 +2103,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     if (row_shift > 0) {
       RowShift<8>(src, num_rows, row_shift);
     }
@@ -2094,9 +2140,16 @@
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<16>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                      row_shift)) {
+      return;
+    }
+
+    const int num_rows = std::min(tx_height, 32);
+    if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
 
@@ -2113,7 +2166,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
    // row_shift is always non-zero here.
     RowShift<16>(src, num_rows, row_shift);
 
@@ -2151,9 +2203,16 @@
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<32>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                      row_shift)) {
+      return;
+    }
+
+    const int num_rows = std::min(tx_height, 32);
+    if (should_round) {
       ApplyRounding<32>(src, num_rows);
     }
     // Process 8 1d dct32 rows in parallel per iteration.
@@ -2162,7 +2221,7 @@
       Dct32_NEON(&src[i * 32], &src[i * 32], 32, /*transpose=*/true);
       i += 8;
     } while (i < num_rows);
-    const uint8_t row_shift = kTransformRowShift[tx_size];
+
    // row_shift is always non-zero here.
     RowShift<32>(src, num_rows, row_shift);
 
@@ -2189,9 +2248,16 @@
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<64>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                      row_shift)) {
+      return;
+    }
+
+    const int num_rows = std::min(tx_height, 32);
+    if (should_round) {
       ApplyRounding<64>(src, num_rows);
     }
     // Process 8 1d dct64 rows in parallel per iteration.
@@ -2200,7 +2266,6 @@
       Dct64_NEON(&src[i * 64], &src[i * 64], 64, /*transpose=*/true);
       i += 8;
     } while (i < num_rows);
-    const uint8_t row_shift = kTransformRowShift[tx_size];
    // row_shift is always non-zero here.
     RowShift<64>(src, num_rows, row_shift);
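
The DctDcOnly<> fast path added above relies on a simple fact: with at most one
non-zero coefficient, a row transform reduces to scaling the DC value and
broadcasting it across the row. Below is a standalone scalar sketch; the Q12
constants 5793 (standing in for kTransformRowMultiplier) and 2896 (standing in
for Cos128(32)) are assumptions made for illustration, and the NEON code applies
them via vqrdmulhq_s16 instead.

#include <cstdint>

bool DctDcOnlySketch(int16_t* dst, const int16_t* src, const int width,
                     const int non_zero_coeff_count, const bool should_round,
                     const int row_shift) {
  if (non_zero_coeff_count > 1) return false;
  int32_t dc = src[0];
  // Optional row rounding stage (~sqrt(2) in Q12). Assumes arithmetic right
  // shift for negative values, as the codebase does elsewhere.
  if (should_round) dc = (dc * 5793 + 2048) >> 12;
  // A DC-only input leaves every 1D DCT output equal to dc * cos(pi/4).
  dc = (dc * 2896 + 2048) >> 12;
  if (row_shift > 0) dc = (dc + (1 << (row_shift - 1))) >> row_shift;
  for (int i = 0; i < width; ++i) dst[i] = static_cast<int16_t>(dc);
  return true;
}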
 
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.cc b/libgav1/src/dsp/arm/loop_restoration_neon.cc
index fd0eda5..62e49ae 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.cc
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.cc
@@ -64,17 +64,17 @@
   sum = vaddq_s16(sum, vreinterpretq_s16_u16(vshll_n_u8(a[3], 4)));
 
   // Saturate to
-  // [0, (1 << (bitdepth + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) -
-  //  1)]
-  //     (1 << (       8 + 1 +                 7 -                   3)) - 1)
+  // [0,
+  // (1 << (bitdepth + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1)]
+  // (1 << (       8 + 1 +                 7 -                         3)) - 1)
   sum = vminq_s16(sum, vdupq_n_s16((1 << 13) - 1));
   sum = vmaxq_s16(sum, vdupq_n_s16(0));
   return sum;
 }
 
 template <int min_width>
-inline void VerticalSum(const int16_t* src_base, const int src_stride,
-                        uint8_t* dst_base, const int dst_stride,
+inline void VerticalSum(const int16_t* src_base, const ptrdiff_t src_stride,
+                        uint8_t* dst_base, const ptrdiff_t dst_stride,
                         const int16x4_t filter[7], const int width,
                         const int height) {
   static_assert(min_width == 4 || min_width == 8, "");
@@ -202,76 +202,31 @@
   // left value.
   const int center_tap = 3;
   src -= center_tap * source_stride + center_tap;
-  // This writes out 2 more rows than we need. It always writes out at least
-  // width 8 for the intermediate buffer.
-  // TODO(johannkoenig): Investigate a 4x specific first pass. May be possible
-  // to do it in half the passes.
   int y = 0;
   do {
     int x = 0;
     do {
-      const uint8_t* src_v = src + x;
-      const uint8x16_t a0 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a1 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a2 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a3 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a4 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a5 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a6 = vld1q_u8(src_v);
-      src_v += source_stride;
-      const uint8x16_t a7 = vld1q_u8(src_v);
+      // This is just as fast as an 8x8 transpose but avoids over-reading extra
+      // rows. It always over-reads by at least 1 value. On small widths (4xH)
+      // it over-reads by 9 values.
+      const uint8x16_t src_v = vld1q_u8(src + x);
+      uint8x8_t b[7];
+      b[0] = vget_low_u8(src_v);
+      b[1] = vget_low_u8(vextq_u8(src_v, src_v, 1));
+      b[2] = vget_low_u8(vextq_u8(src_v, src_v, 2));
+      b[3] = vget_low_u8(vextq_u8(src_v, src_v, 3));
+      b[4] = vget_low_u8(vextq_u8(src_v, src_v, 4));
+      b[5] = vget_low_u8(vextq_u8(src_v, src_v, 5));
+      b[6] = vget_low_u8(vextq_u8(src_v, src_v, 6));
 
-      uint8x8_t b[16];
+      int16x8_t sum = HorizontalSum(b, filter);
 
-      // This could load and transpose one 8x8 block to prime the loop, then
-      // load and transpose a second block in the loop. The second block could
-      // be passed to subsequent iterations.
-      // TODO(johannkoenig): convert these to arrays.
-      Transpose16x8(a0, a1, a2, a3, a4, a5, a6, a7, b, b + 1, b + 2, b + 3,
-                    b + 4, b + 5, b + 6, b + 7, b + 8, b + 9, b + 10, b + 11,
-                    b + 12, b + 13, b + 14, b + 15);
-
-      int16x8_t sum_0 = HorizontalSum(b, filter);
-      int16x8_t sum_1 = HorizontalSum(b + 1, filter);
-      int16x8_t sum_2 = HorizontalSum(b + 2, filter);
-      int16x8_t sum_3 = HorizontalSum(b + 3, filter);
-      int16x8_t sum_4 = HorizontalSum(b + 4, filter);
-      int16x8_t sum_5 = HorizontalSum(b + 5, filter);
-      int16x8_t sum_6 = HorizontalSum(b + 6, filter);
-      int16x8_t sum_7 = HorizontalSum(b + 7, filter);
-
-      // TODO(johannkoenig): convert this to an array.
-      Transpose8x8(&sum_0, &sum_1, &sum_2, &sum_3, &sum_4, &sum_5, &sum_6,
-                   &sum_7);
-
-      int16_t* wiener_buffer_v = wiener_buffer + x;
-      vst1q_s16(wiener_buffer_v, sum_0);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_1);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_2);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_3);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_4);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_5);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_6);
-      wiener_buffer_v += buffer_stride;
-      vst1q_s16(wiener_buffer_v, sum_7);
+      vst1q_s16(wiener_buffer + x, sum);
       x += 8;
     } while (x < width);
-    src += 8 * source_stride;
-    wiener_buffer += 8 * buffer_stride;
-    y += 8;
-  } while (y < height + kSubPixelTaps - 2);
+    src += source_stride;
+    wiener_buffer += buffer_stride;
+  } while (++y < height + kSubPixelTaps - 2);
 
   // Vertical filtering.
   wiener_buffer = reinterpret_cast<int16_t*>(buffer->wiener_buffer);
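
In scalar terms, the vextq_u8 construction above builds seven overlapping
8-byte windows from one 16-byte load: lane i of b[k] is src[x + i + k]. A
sketch of that equivalence (illustrative only); the single load reads 16 bytes
although the filter consumes at most 14 of them, which is the over-read the
new comment calls out.

#include <cstdint>

void BuildWindowsSketch(const uint8_t* src, const int x, uint8_t b[7][8]) {
  for (int k = 0; k < 7; ++k) {
    for (int i = 0; i < 8; ++i) {
      // Matches vget_low_u8(vextq_u8(v, v, k)) on the load at src + x.
      b[k][i] = src[x + i + k];
    }
  }
}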
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.h b/libgav1/src/dsp/arm/loop_restoration_neon.h
index 1723c50..68f9526 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.h
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.h
@@ -16,8 +16,8 @@
 
 #if LIBGAV1_ENABLE_NEON
 
-#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_DSP_NEON
 
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/warp_neon.cc b/libgav1/src/dsp/arm/warp_neon.cc
index 7ab3d27..f5673b0 100644
--- a/libgav1/src/dsp/arm/warp_neon.cc
+++ b/libgav1/src/dsp/arm/warp_neon.cc
@@ -27,7 +27,7 @@
 void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
                const int source_width, const int source_height,
                const int* const warp_params, const int subsampling_x,
-               const int subsampling_y, const uint8_t inter_round_bits_vertical,
+               const int subsampling_y, const int inter_round_bits_vertical,
                const int block_start_x, const int block_start_y,
                const int block_width, const int block_height,
                const int16_t alpha, const int16_t beta, const int16_t gamma,
@@ -48,9 +48,6 @@
   assert(block_width >= 8);
   assert(block_height >= 8);
 
-  const uint8x16_t index_vec = vcombine_u8(vcreate_u8(0x0706050403020100),
-                                           vcreate_u8(0x0f0e0d0c0b0a0908));
-
   // Warp process applies for each 8x8 block (or smaller).
   int start_y = block_start_y;
   do {
@@ -87,7 +84,7 @@
               (horizontal_offset >> kInterRoundBitsHorizontal) +
               (src_row[source_width - 1] << (7 - kInterRoundBitsHorizontal));
           const int16x8_t sum = vdupq_n_s16(s);
-          vst1q_s16(&intermediate_result[y + 7][0], sum);
+          vst1q_s16(intermediate_result[y + 7], sum);
           sx4 += beta;
           continue;
         }
@@ -98,7 +95,7 @@
           const int16_t s = (horizontal_offset >> kInterRoundBitsHorizontal) +
                             (src_row[0] << (7 - kInterRoundBitsHorizontal));
           const int16x8_t sum = vdupq_n_s16(s);
-          vst1q_s16(&intermediate_result[y + 7][0], sum);
+          vst1q_s16(intermediate_result[y + 7], sum);
           sx4 += beta;
           continue;
         }
@@ -106,31 +103,11 @@
         // read but is ignored.
         //
         // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
-        // bytes after src_row[source_width - 1]. There must be enough padding
-        // before and after the |source| buffer.
-        static_assert(kBorderPixels >= 14, "");
-        uint8x16_t src_row_u8 = vld1q_u8(&src_row[ix4 - 7]);
-        // If clipping is needed, duplicate the border samples.
-        //
-        // Here is the correspondence between the index for the src_row
-        // buffer and the index for the src_row_u8 vector:
-        //
-        // src_row index    : (ix4 - 7) (ix4 - 6) ... (ix4 + 6) (ix4 + 7)
-        // src_row_u8 index :     0         1     ...     13        14
-        if (ix4 - 7 < 0) {
-          const int out_of_boundary_left = -(ix4 - 7);
-          const uint8x16_t cmp_vec = vdupq_n_u8(out_of_boundary_left);
-          const uint8x16_t vec_dup = vdupq_n_u8(src_row[0]);
-          const uint8x16_t mask_val = vcltq_u8(index_vec, cmp_vec);
-          src_row_u8 = vbslq_u8(mask_val, vec_dup, src_row_u8);
-        }
-        if (ix4 + 7 > source_width - 1) {
-          const int out_of_boundary_right = ix4 + 8 - source_width;
-          const uint8x16_t cmp_vec = vdupq_n_u8(14 - out_of_boundary_right);
-          const uint8x16_t vec_dup = vdupq_n_u8(src_row[source_width - 1]);
-          const uint8x16_t mask_val = vcgtq_u8(index_vec, cmp_vec);
-          src_row_u8 = vbslq_u8(mask_val, vec_dup, src_row_u8);
-        }
+        // bytes after src_row[source_width - 1]. We assume the source frame
+        // has left and right borders of at least 13 bytes that extend the
+        // frame boundary pixels. We also assume there is at least one extra
+        // padding byte after the right border of the last source row.
+        const uint8x16_t src_row_u8 = vld1q_u8(&src_row[ix4 - 7]);
         const int16x8_t src_row_low_s16 =
             vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_row_u8)));
         const int16x8_t src_row_high_s16 =
@@ -201,7 +178,7 @@
         // Treat sum as unsigned for the right shift.
         sum = vreinterpretq_s16_u16(vrshrq_n_u16(vreinterpretq_u16_s16(sum),
                                                  kInterRoundBitsHorizontal));
-        vst1q_s16(&intermediate_result[y + 7][0], sum);
+        vst1q_s16(intermediate_result[y + 7], sum);
         sx4 += beta;
       }
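
The deleted clamping above becomes a contract on the caller: the frame must
carry replicated borders. A small standalone check (hypothetical helper)
encoding the read span stated in the NOTE:

#include <cassert>

void CheckWarpReadSpan(const int ix4, const int source_width) {
  const int first_read = ix4 - 7;        // leftmost byte of the vld1q_u8
  const int last_read = (ix4 - 7) + 15;  // vld1q_u8 touches 16 bytes
  // Up to 13 bytes before src_row[0] and up to 14 bytes after
  // src_row[source_width - 1] must be valid border/padding.
  assert(first_read >= -13);
  assert(last_read <= (source_width - 1) + 14);
}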
 
diff --git a/libgav1/src/dsp/average_blend.cc b/libgav1/src/dsp/average_blend.cc
index a5de6d3..438501a 100644
--- a/libgav1/src/dsp/average_blend.cc
+++ b/libgav1/src/dsp/average_blend.cc
@@ -28,19 +28,22 @@
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
 
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       // prediction range: 8bpp: [0, 15471] 10bpp: [0, 61983] 12bpp: [0, 62007].
       int res = prediction_0[x] + prediction_1[x];
       res -= compound_round_offset;
       dst[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(res, inter_post_round_bits + 1), 0,
                 (1 << bitdepth) - 1));
-    }
+    } while (++x < width);
+
     dst += dst_stride;
     prediction_0 += prediction_stride_0;
     prediction_1 += prediction_stride_1;
-  }
+  } while (++y < height);
 }
 
 void Init8bpp() {
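
average_blend.cc is the first of several files below (cdef.cc, convolve.cc,
distance_weighted_blend.cc, film_grain.cc) converted from for loops to
do/while. The shared shape, which assumes width > 0 and height > 0 so the
entry test of each loop can be dropped, is (illustrative only):

#include <cassert>

void ProcessRowsSketch(const int width, const int height) {
  assert(width > 0 && height > 0);  // do/while skips the initial trip test.
  int y = 0;
  do {
    int x = 0;
    do {
      // ... per-pixel work, as in the blend and filter kernels here ...
    } while (++x < width);
  } while (++y < height);
}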
diff --git a/libgav1/src/dsp/cdef.cc b/libgav1/src/dsp/cdef.cc
index e412342..84dc778 100644
--- a/libgav1/src/dsp/cdef.cc
+++ b/libgav1/src/dsp/cdef.cc
@@ -115,8 +115,10 @@
   const auto* src = static_cast<const uint16_t*>(source);
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
-  for (int y = 0; y < block_height; ++y) {
-    for (int x = 0; x < block_width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       int16_t sum = 0;
       const uint16_t pixel_value = src[x];
       uint16_t max_value = pixel_value;
@@ -156,10 +158,11 @@
 
       dst[x] = static_cast<Pixel>(Clip3(
           pixel_value + ((8 + sum - (sum < 0)) >> 4), min_value, max_value));
-    }
+    } while (++x < block_width);
+
     src += source_stride;
     dst += dst_stride;
-  }
+  } while (++y < block_height);
 }
 
 void Init8bpp() {
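
The expression (8 + sum - (sum < 0)) >> 4 in the CDEF filter above divides
|sum| by 16, rounding halves away from zero, without a branch. A standalone
check of that reading (assumes arithmetic right shift of negative values):

#include <cassert>
#include <cstdlib>

int main() {
  for (int sum = -4096; sum <= 4096; ++sum) {
    const int branchless = (8 + sum - (sum < 0)) >> 4;
    const int sign = (sum < 0) ? -1 : 1;
    const int reference = sign * ((std::abs(sum) + 8) >> 4);
    assert(branchless == reference);
  }
  return 0;
}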
diff --git a/libgav1/src/dsp/convolve.cc b/libgav1/src/dsp/convolve.cc
index 8c409c8..350196a 100644
--- a/libgav1/src/dsp/convolve.cc
+++ b/libgav1/src/dsp/convolve.cc
@@ -36,7 +36,7 @@
 void ConvolveScale2D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int vertical_filter_index,
-    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int inter_round_bits_vertical, const int subpixel_x,
     const int subpixel_y, const int step_x, const int step_y, const int width,
     const int height, void* prediction, const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
@@ -68,8 +68,11 @@
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
   // Note: assume the input src is already aligned to the correct start
   // position.
-  for (int y = 0; y < intermediate_height; ++y) {
-    for (int x = 0, p = subpixel_x; x < width; ++x, p += step_x) {
+  int y = 0;
+  do {
+    int p = subpixel_x;
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << (bitdepth + kFilterBits - 1);
       const Pixel* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
@@ -80,17 +83,23 @@
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
           RightShiftWithRounding(sum, kRoundBitsHorizontal));
-    }
+      p += step_x;
+    } while (++x < width);
+
     src += src_stride;
     intermediate += intermediate_stride;
-  }
+  } while (++y < intermediate_height);
+
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
-  for (int y = 0, p = subpixel_y & 1023; y < height; ++y, p += step_y) {
+  int p = subpixel_y & 1023;
+  y = 0;
+  do {
     const int filter_id = (p >> 6) & kSubPixelMask;
-    for (int x = 0; x < width; ++x) {
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << offset_bits;
       for (int k = 0; k < kSubPixelTaps; ++k) {
@@ -104,16 +113,18 @@
           Clip3(RightShiftWithRounding(sum, inter_round_bits_vertical) -
                     single_round_offset,
                 0, max_pixel_value));
-    }
+    } while (++x < width);
+
     dest += dest_stride;
-  }
+    p += step_y;
+  } while (++y < height);
 }
 
 template <int bitdepth, typename Pixel>
 void ConvolveCompoundScale2D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int vertical_filter_index,
-    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int inter_round_bits_vertical, const int subpixel_x,
     const int subpixel_y, const int step_x, const int step_y, const int width,
     const int height, void* prediction, const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
@@ -142,8 +153,11 @@
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
   // Note: assume the input src is already aligned to the correct start
   // position.
-  for (int y = 0; y < intermediate_height; ++y) {
-    for (int x = 0, p = subpixel_x; x < width; ++x, p += step_x) {
+  int y = 0;
+  do {
+    int p = subpixel_x;
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << (bitdepth + kFilterBits - 1);
       const Pixel* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
@@ -154,17 +168,23 @@
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
           RightShiftWithRounding(sum, kRoundBitsHorizontal));
-    }
+      p += step_x;
+    } while (++x < width);
+
     src += src_stride;
     intermediate += intermediate_stride;
-  }
+  } while (++y < intermediate_height);
+
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
-  for (int y = 0, p = subpixel_y & 1023; y < height; ++y, p += step_y) {
+  int p = subpixel_y & 1023;
+  y = 0;
+  do {
     const int filter_id = (p >> 6) & kSubPixelMask;
-    for (int x = 0; x < width; ++x) {
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << offset_bits;
       for (int k = 0; k < kSubPixelTaps; ++k) {
@@ -176,9 +196,11 @@
       assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
       dest[x] = static_cast<uint16_t>(
           RightShiftWithRounding(sum, inter_round_bits_vertical));
-    }
+    } while (++x < width);
+
     dest += pred_stride;
-  }
+    p += step_y;
+  } while (++y < height);
 }
 
 template <int bitdepth, typename Pixel>
@@ -186,7 +208,7 @@
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int vertical_filter_index,
-                          const uint8_t inter_round_bits_vertical,
+                          const int inter_round_bits_vertical,
                           const int subpixel_x, const int subpixel_y,
                           const int /*step_x*/, const int /*step_y*/,
                           const int width, const int height, void* prediction,
@@ -213,8 +235,10 @@
                     kVerticalOffset * src_stride - kHorizontalOffset;
   auto* dest = static_cast<uint16_t*>(prediction);
   int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-  for (int y = 0; y < intermediate_height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << (bitdepth + kFilterBits - 1);
       for (int k = 0; k < kSubPixelTaps; ++k) {
@@ -223,17 +247,21 @@
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
           RightShiftWithRounding(sum, kRoundBitsHorizontal));
-    }
+    } while (++x < width);
+
     src += src_stride;
     intermediate += intermediate_stride;
-  }
+  } while (++y < intermediate_height);
+
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
   const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  y = 0;
+  do {
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << offset_bits;
       for (int k = 0; k < kSubPixelTaps; ++k) {
@@ -243,10 +271,11 @@
       assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
       dest[x] = static_cast<uint16_t>(
           RightShiftWithRounding(sum, inter_round_bits_vertical));
-    }
+    } while (++x < width);
+
     dest += pred_stride;
     intermediate += intermediate_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is a simplified version of ConvolveCompound2D_C.
@@ -258,7 +287,7 @@
 void Convolve2D_C(const void* const reference, const ptrdiff_t reference_stride,
                   const int horizontal_filter_index,
                   const int vertical_filter_index,
-                  const uint8_t inter_round_bits_vertical, const int subpixel_x,
+                  const int inter_round_bits_vertical, const int subpixel_x,
                   const int subpixel_y, const int /*step_x*/,
                   const int /*step_y*/, const int width, const int height,
                   void* prediction, const ptrdiff_t pred_stride) {
@@ -287,8 +316,10 @@
   auto* dest = static_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-  for (int y = 0; y < intermediate_height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << (bitdepth + kFilterBits - 1);
       for (int k = 0; k < kSubPixelTaps; ++k) {
@@ -297,17 +328,21 @@
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
           RightShiftWithRounding(sum, kRoundBitsHorizontal));
-    }
+    } while (++x < width);
+
     src += src_stride;
     intermediate += intermediate_stride;
-  }
+  } while (++y < intermediate_height);
+
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
   const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  y = 0;
+  do {
+    int x = 0;
+    do {
      // An offset to guarantee the sum is non-negative.
       int sum = 1 << offset_bits;
       for (int k = 0; k < kSubPixelTaps; ++k) {
@@ -319,10 +354,11 @@
           Clip3(RightShiftWithRounding(sum, inter_round_bits_vertical) -
                     single_round_offset,
                 0, max_pixel_value));
-    }
+    } while (++x < width);
+
     dest += dest_stride;
     intermediate += intermediate_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is a simplified version of Convolve2D_C.
@@ -335,7 +371,7 @@
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int /*vertical_filter_index*/,
-                          const uint8_t /*inter_round_bits_vertical*/,
+                          const int /*inter_round_bits_vertical*/,
                           const int subpixel_x, const int /*subpixel_y*/,
                           const int /*step_x*/, const int /*step_y*/,
                           const int width, const int height, void* prediction,
@@ -351,8 +387,10 @@
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
   const int max_pixel_value = (1 << bitdepth) - 1;
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
@@ -360,10 +398,11 @@
       sum = RightShiftWithRounding(sum, kRoundBitsHorizontal);
       dest[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(sum, bits), 0, max_pixel_value));
-    }
+    } while (++x < width);
+
     src += src_stride;
     dest += dest_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is a simplified version of Convolve2D_C.
@@ -376,7 +415,7 @@
                         const ptrdiff_t reference_stride,
                         const int /*horizontal_filter_index*/,
                         const int vertical_filter_index,
-                        const uint8_t /*inter_round_bits_vertical*/,
+                        const int /*inter_round_bits_vertical*/,
                         const int /*subpixel_x*/, const int subpixel_y,
                         const int /*step_x*/, const int /*step_y*/,
                         const int width, const int height, void* prediction,
@@ -401,8 +440,10 @@
     return;
   }
   const int max_pixel_value = (1 << bitdepth) - 1;
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum += kSubPixelFilters[filter_index][filter_id][k] *
@@ -410,10 +451,11 @@
       }
       dest[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(sum, kFilterBits), 0, max_pixel_value));
-    }
+    } while (++x < width);
+
     src += src_stride;
     dest += dest_stride;
-  }
+  } while (++y < height);
 }
 
 template <int bitdepth, typename Pixel>
@@ -421,18 +463,19 @@
                     const ptrdiff_t reference_stride,
                     const int /*horizontal_filter_index*/,
                     const int /*vertical_filter_index*/,
-                    const uint8_t /*inter_round_bits_vertical*/,
+                    const int /*inter_round_bits_vertical*/,
                     const int /*subpixel_x*/, const int /*subpixel_y*/,
                     const int /*step_x*/, const int /*step_y*/, const int width,
                     const int height, void* prediction,
                     const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
-  for (int y = 0; y < height; ++y) {
+  int y = 0;
+  do {
     memcpy(dest, src, width * sizeof(Pixel));
     src += reference_stride;
     dest += pred_stride;
-  }
+  } while (++y < height);
 }
 
 template <int bitdepth, typename Pixel>
@@ -440,7 +483,7 @@
                             const ptrdiff_t reference_stride,
                             const int /*horizontal_filter_index*/,
                             const int /*vertical_filter_index*/,
-                            const uint8_t /*inter_round_bits_vertical*/,
+                            const int /*inter_round_bits_vertical*/,
                             const int /*subpixel_x*/, const int /*subpixel_y*/,
                             const int /*step_x*/, const int /*step_y*/,
                             const int width, const int height, void* prediction,
@@ -450,13 +493,16 @@
   auto* dest = static_cast<uint16_t*>(prediction);
   const int compound_round_offset =
       (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       dest[x] = (src[x] << 4) + compound_round_offset;
-    }
+    } while (++x < width);
+
     src += src_stride;
     dest += pred_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is a simplified version of ConvolveCompound2D_C.
@@ -468,7 +514,7 @@
 void ConvolveCompoundHorizontal_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int inter_round_bits_vertical, const int subpixel_x,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -483,18 +529,21 @@
   const int bits_shift = kFilterBits - inter_round_bits_vertical;
   const int compound_round_offset =
       (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
       sum = RightShiftWithRounding(sum, kRoundBitsHorizontal) << bits_shift;
       dest[x] = sum + compound_round_offset;
-    }
+    } while (++x < width);
+
     src += src_stride;
     dest += pred_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is a simplified version of ConvolveCompound2D_C.
@@ -507,7 +556,7 @@
                                 const ptrdiff_t reference_stride,
                                 const int /*horizontal_filter_index*/,
                                 const int vertical_filter_index,
-                                const uint8_t inter_round_bits_vertical,
+                                const int inter_round_bits_vertical,
                                 const int /*subpixel_x*/, const int subpixel_y,
                                 const int /*step_x*/, const int /*step_y*/,
                                 const int width, const int height,
@@ -524,8 +573,10 @@
   const int bits_shift = kFilterBits - kRoundBitsHorizontal;
   const int compound_round_offset =
       (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum += kSubPixelFilters[filter_index][filter_id][k] *
@@ -534,10 +585,11 @@
       dest[x] = RightShiftWithRounding(LeftShift(sum, bits_shift),
                                        inter_round_bits_vertical) +
                 compound_round_offset;
-    }
+    } while (++x < width);
+
     src += src_stride;
     dest += pred_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is used when intra block copy is present.
@@ -550,7 +602,7 @@
 void ConvolveIntraBlockCopy2D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
+    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -565,22 +617,29 @@
   // Note: allow vertical access to height + 1. Because this function is only
   // for u/v plane of intra block copy, such access is guaranteed to be within
   // the prediction block.
-  for (int y = 0; y < intermediate_height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       intermediate[x] = src[x] + src[x + 1];
-    }
+    } while (++x < width);
+
     src += src_stride;
     intermediate += width;
-  }
+  } while (++y < intermediate_height);
+
   intermediate = intermediate_result;
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  y = 0;
+  do {
+    int x = 0;
+    do {
       dest[x] = static_cast<Pixel>(
           RightShiftWithRounding(intermediate[x] + intermediate[x + width], 2));
-    }
+    } while (++x < width);
+
     intermediate += width;
     dest += dest_stride;
-  }
+  } while (++y < height);
 }
 
 // This function is used when intra block copy is present.
@@ -595,7 +654,7 @@
 void ConvolveIntraBlockCopy1D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
+    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -604,14 +663,17 @@
   auto* dest = reinterpret_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   const ptrdiff_t offset = is_horizontal ? 1 : src_stride;
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       dest[x] = static_cast<Pixel>(
           RightShiftWithRounding(src[x] + src[x + offset], 1));
-    }
+    } while (++x < width);
+
     src += src_stride;
     dest += dest_stride;
-  }
+  } while (++y < height);
 }
 
 void Init8bpp() {
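
The C 2D paths above are the reference model for the NEON 2D rewrite earlier
in this patch: a horizontal pass fills height + taps - 1 intermediate rows so
the vertical pass has a full tap window for every output row. A minimal scalar
sketch of that staging (names illustrative; rounding offsets and clipping
omitted for brevity):

#include <cstdint>

void Separable2DSketch(const uint8_t* src, const int src_stride,
                       int16_t* intermediate, uint8_t* dst,
                       const int dst_stride, const int width,
                       const int height, const int taps,
                       const int16_t* h_filter, const int16_t* v_filter) {
  const int intermediate_height = height + taps - 1;
  for (int y = 0; y < intermediate_height; ++y) {
    for (int x = 0; x < width; ++x) {
      int sum = 0;
      for (int k = 0; k < taps; ++k) {
        sum += h_filter[k] * src[y * src_stride + x + k];
      }
      // Stand-in for the kInterRoundBitsHorizontal rounding stage.
      intermediate[y * width + x] = static_cast<int16_t>(sum >> 3);
    }
  }
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      int sum = 0;
      for (int k = 0; k < taps; ++k) {
        sum += v_filter[k] * intermediate[(y + k) * width + x];
      }
      // Placeholder final shift; the real code rounds and clips.
      dst[y * dst_stride + x] = static_cast<uint8_t>(sum >> 11);
    }
  }
}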
diff --git a/libgav1/src/dsp/distance_weighted_blend.cc b/libgav1/src/dsp/distance_weighted_blend.cc
index 05e8e3f..7213209 100644
--- a/libgav1/src/dsp/distance_weighted_blend.cc
+++ b/libgav1/src/dsp/distance_weighted_blend.cc
@@ -29,8 +29,10 @@
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
 
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       // prediction range: 8bpp: [0, 15471] 10bpp: [0, 61983] 12bpp: [0, 62007].
       // weight_0 + weight_1 = 16.
       int res = prediction_0[x] * weight_0 + prediction_1[x] * weight_1;
@@ -38,11 +40,12 @@
       dst[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(res, inter_post_round_bits + 4), 0,
                 (1 << bitdepth) - 1));
-    }
+    } while (++x < width);
+
     dst += dst_stride;
     prediction_0 += prediction_stride_0;
     prediction_1 += prediction_stride_1;
-  }
+  } while (++y < height);
 }
 
 void Init8bpp() {
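
In the distance weighted blend above the two weights always sum to 16, so the
product sum is sixteen times a weighted mean of the predictors before the
offset and shift are applied. A quick worked example (the weight pair is
illustrative; the spec derives it from reference-frame distances):

#include <cassert>

int main() {
  const int weight_0 = 9;
  const int weight_1 = 7;
  assert(weight_0 + weight_1 == 16);
  const int p0 = 5000;  // values inside the 8bpp prediction range noted above
  const int p1 = 6000;
  const int res = p0 * weight_0 + p1 * weight_1;
  assert(res == 87000);  // 16 * 5437.5, i.e. 16x the weighted mean
  return 0;
}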
diff --git a/libgav1/src/dsp/dsp.h b/libgav1/src/dsp/dsp.h
index 631a42c..48b3371 100644
--- a/libgav1/src/dsp/dsp.h
+++ b/libgav1/src/dsp/dsp.h
@@ -385,10 +385,10 @@
 using ConvolveFunc = void (*)(const void* reference, ptrdiff_t reference_stride,
                               int vertical_filter_index,
                               int horizontal_filter_index,
-                              const uint8_t inter_round_bits_vertical,
-                              int subpixel_x, int subpixel_y, int step_x,
-                              int step_y, int width, int height,
-                              void* prediction, ptrdiff_t pred_stride);
+                              int inter_round_bits_vertical, int subpixel_x,
+                              int subpixel_y, int step_x, int step_y, int width,
+                              int height, void* prediction,
+                              ptrdiff_t pred_stride);
 
 // Convolve function signature. Each points to one convolve function with
 // a specific setting:
@@ -533,14 +533,15 @@
 // |dest| is the output buffer. It is a predictor, whose type is int16_t.
 // |dest_stride| is the stride, in units of int16_t.
 //
-// NOTE: The ARM NEON implementation of WarpFunc may read up to 13 bytes before
-// the |source| buffer or up to 14 bytes after the |source| buffer. Therefore,
-// there must be enough padding before and after the |source| buffer.
+// NOTE: WarpFunc assumes the source frame has left and right borders that
+// extend the frame boundary pixels. The left and right borders must be at
+// least 13 pixels wide. In addition, Warp_NEON() may read up to 14 bytes after
+// a row in the |source| buffer. Therefore, there must be at least one extra
+// padding byte after the right border of the last row in the source buffer.
 using WarpFunc = void (*)(const void* source, ptrdiff_t source_stride,
                           int source_width, int source_height,
                           const int* warp_params, int subsampling_x,
-                          int subsampling_y,
-                          const uint8_t inter_round_bits_vertical,
+                          int subsampling_y, int inter_round_bits_vertical,
                           int block_start_x, int block_start_y, int block_width,
                           int block_height, int16_t alpha, int16_t beta,
                           int16_t gamma, int16_t delta, uint16_t* dest,
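
One way to satisfy the border requirement in the NOTE above is to allocate each
source row with replicated edge pixels on both sides; a hedged sketch of the
buffer layout (kWarpBorder and MakeBorderedRow are hypothetical names, and for
simplicity the extra trailing byte is added to every row rather than only the
last one):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    constexpr int kWarpBorder = 13;  // minimum left/right border, per the NOTE

    std::vector<uint8_t> MakeBorderedRow(const uint8_t* row, int width) {
      std::vector<uint8_t> padded(kWarpBorder + width + kWarpBorder + 1);
      std::memset(padded.data(), row[0], kWarpBorder);       // left border
      std::memcpy(padded.data() + kWarpBorder, row, width);  // frame pixels
      std::memset(padded.data() + kWarpBorder + width, row[width - 1],
                  kWarpBorder + 1);  // right border plus one padding byte
      return padded;
    }
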
diff --git a/libgav1/src/dsp/film_grain.cc b/libgav1/src/dsp/film_grain.cc
index b26b517..c7a5a7a 100644
--- a/libgav1/src/dsp/film_grain.cc
+++ b/libgav1/src/dsp/film_grain.cc
@@ -383,15 +383,16 @@
 void FilmGrain<bitdepth>::ApplyAutoRegressiveFilterToLumaGrain(
     const FilmGrainParams& params, int grain_min, int grain_max,
     GrainType* luma_grain) {
+  assert(params.auto_regression_coeff_lag <= 3);
   const int shift = params.auto_regression_shift;
   for (int y = 3; y < kLumaHeight; ++y) {
     for (int x = 3; x < kLumaWidth - 3; ++x) {
       int sum = 0;
       int pos = 0;
-      for (int delta_row = -params.auto_regression_coeff_lag; delta_row <= 0;
-           ++delta_row) {
-        for (int delta_column = -params.auto_regression_coeff_lag;
-             delta_column <= params.auto_regression_coeff_lag; ++delta_column) {
+      int delta_row = -params.auto_regression_coeff_lag;
+      do {
+        int delta_column = -params.auto_regression_coeff_lag;
+        do {
           if (delta_row == 0 && delta_column == 0) {
             break;
           }
@@ -399,8 +400,8 @@
           sum += luma_grain[(y + delta_row) * kLumaWidth + (x + delta_column)] *
                  coeff;
           ++pos;
-        }
-      }
+        } while (++delta_column <= params.auto_regression_coeff_lag);
+      } while (++delta_row <= 0);
       luma_grain[y * kLumaWidth + x] = Clip3(
           luma_grain[y * kLumaWidth + x] + RightShiftWithRounding(sum, shift),
           grain_min, grain_max);
@@ -420,26 +421,34 @@
   } else {
     uint16_t seed = params.grain_seed ^ 0xb524;
     GrainType* u_grain_row = u_grain;
-    for (int y = 0; y < chroma_height; ++y) {
-      for (int x = 0; x < chroma_width; ++x) {
+    assert(chroma_width > 0);
+    assert(chroma_height > 0);
+    int y = 0;
+    do {
+      int x = 0;
+      do {
         u_grain_row[x] = RightShiftWithRounding(
             kGaussianSequence[GetRandomNumber(11, &seed)], shift);
-      }
+      } while (++x < chroma_width);
+
       u_grain_row += chroma_width;
-    }
+    } while (++y < chroma_height);
   }
   if (params.num_v_points == 0 && !params.chroma_scaling_from_luma) {
     memset(v_grain, 0, chroma_height * chroma_width * sizeof(*v_grain));
   } else {
     GrainType* v_grain_row = v_grain;
     uint16_t seed = params.grain_seed ^ 0x49d8;
-    for (int y = 0; y < chroma_height; ++y) {
-      for (int x = 0; x < chroma_width; ++x) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
         v_grain_row[x] = RightShiftWithRounding(
             kGaussianSequence[GetRandomNumber(11, &seed)], shift);
-      }
+      } while (++x < chroma_width);
+
       v_grain_row += chroma_width;
-    }
+    } while (++y < chroma_height);
   }
 }
 
@@ -449,16 +458,17 @@
     const GrainType* luma_grain, int subsampling_x, int subsampling_y,
     int chroma_width, int chroma_height, GrainType* u_grain,
     GrainType* v_grain) {
+  assert(params.auto_regression_coeff_lag <= 3);
   const int shift = params.auto_regression_shift;
   for (int y = 3; y < chroma_height; ++y) {
     for (int x = 3; x < chroma_width - 3; ++x) {
       int sum_u = 0;
       int sum_v = 0;
       int pos = 0;
-      for (int delta_row = -params.auto_regression_coeff_lag; delta_row <= 0;
-           ++delta_row) {
-        for (int delta_column = -params.auto_regression_coeff_lag;
-             delta_column <= params.auto_regression_coeff_lag; ++delta_column) {
+      int delta_row = -params.auto_regression_coeff_lag;
+      do {
+        int delta_column = -params.auto_regression_coeff_lag;
+        do {
           const int coeff_u = params.auto_regression_coeff_u[pos];
           const int coeff_v = params.auto_regression_coeff_v[pos];
           if (delta_row == 0 && delta_column == 0) {
@@ -466,11 +476,13 @@
               int luma = 0;
               const int luma_x = ((x - 3) << subsampling_x) + 3;
               const int luma_y = ((y - 3) << subsampling_y) + 3;
-              for (int i = 0; i <= subsampling_y; ++i) {
-                for (int j = 0; j <= subsampling_x; ++j) {
+              int i = 0;
+              do {
+                int j = 0;
+                do {
                   luma += luma_grain[(luma_y + i) * kLumaWidth + (luma_x + j)];
-                }
-              }
+                } while (++j <= subsampling_x);
+              } while (++i <= subsampling_y);
               luma =
                   RightShiftWithRounding(luma, subsampling_x + subsampling_y);
               sum_u += luma * coeff_u;
@@ -485,8 +497,8 @@
               v_grain[(y + delta_row) * chroma_width + (x + delta_column)] *
               coeff_v;
           ++pos;
-        }
-      }
+        } while (++delta_column <= params.auto_regression_coeff_lag);
+      } while (++delta_row <= 0);
       u_grain[y * chroma_width + x] = Clip3(
           u_grain[y * chroma_width + x] + RightShiftWithRounding(sum_u, shift),
           grain_min, grain_max);
@@ -556,7 +568,9 @@
   if (noise_buffer_ == nullptr) return false;
   GrainType* noise_block = noise_buffer_.get();
   int luma_num = 0;
-  for (int y = 0; y < half_height; y += 16) {
+  assert(half_height > 0);
+  int y = 0;
+  do {
     noise_stripe_[luma_num][kPlaneY] = noise_block;
     noise_block += 34 * width_;
     if (!is_monochrome_) {
@@ -568,7 +582,8 @@
                      RightShiftWithRounding(width_, subsampling_x_);
     }
     ++luma_num;
-  }
+    y += 16;
+  } while (y < half_height);
   assert(noise_block == noise_buffer_.get() + noise_buffer_size);
   return true;
 }
@@ -579,11 +594,15 @@
   const int half_width = DivideBy2(width_ + 1);
   const int half_height = DivideBy2(height_ + 1);
   int luma_num = 0;
-  for (int y = 0; y < half_height; y += 16) {
+  assert(half_width > 0);
+  assert(half_height > 0);
+  int y = 0;
+  do {
     uint16_t seed = params_.grain_seed;
     seed ^= ((luma_num * 37 + 178) & 255) << 8;
     seed ^= ((luma_num * 173 + 105) & 255);
-    for (int x = 0; x < half_width; x += 16) {
+    int x = 0;
+    do {
       const int rand = GetRandomNumber(8, &seed);
       const int offset_x = rand >> 4;
       const int offset_y = rand & 15;
@@ -596,8 +615,10 @@
             (plane_sub_y != 0) ? 6 + offset_y : 9 + offset_y * 2;
         GrainType* const noise_block = noise_stripe_[luma_num][plane];
         const int noise_block_width = (width_ + plane_sub_x) >> plane_sub_x;
-        for (int i = 0; i < (34 >> plane_sub_y); ++i) {
-          for (int j = 0; j < (34 >> plane_sub_x); ++j) {
+        int i = 0;
+        do {
+          int j = 0;
+          do {
             int grain;
             if (plane == kPlaneY) {
               grain = luma_grain_[(plane_offset_y + i) * kLumaWidth +
@@ -644,12 +665,15 @@
               }
               noise_block[i * noise_block_width + (x + j)] = grain;
             }
-          }
-        }
+          } while (++j < (34 >> plane_sub_x));
+        } while (++i < (34 >> plane_sub_y));
       }
-    }
+      x += 16;
+    } while (x < half_width);
+
     ++luma_num;
-  }
+    y += 16;
+  } while (y < half_height);
 }
 
 template <int bitdepth>
@@ -682,10 +706,12 @@
     const int plane_sub_x = (plane > kPlaneY) ? subsampling_x_ : 0;
     const int plane_sub_y = (plane > kPlaneY) ? subsampling_y_ : 0;
     const int noise_block_width = (width_ + plane_sub_x) >> plane_sub_x;
-    for (int y = 0; y < ((height_ + plane_sub_y) >> plane_sub_y); ++y) {
+    int y = 0;
+    do {
       const int luma_num = y >> (5 - plane_sub_y);
       const int i = y - (luma_num << (5 - plane_sub_y));
-      for (int x = 0; x < noise_block_width; ++x) {
+      int x = 0;
+      do {
         int grain = noise_stripe_[luma_num][plane][i * noise_block_width + x];
         if (plane_sub_y == 0) {
           if (i < 2 && luma_num > 0 && params_.overlap_flag) {
@@ -709,8 +735,8 @@
           }
         }
         noise_image_[plane][y][x] = grain;
-      }
-    }
+      } while (++x < noise_block_width);
+    } while (++y < ((height_ + plane_sub_y) >> plane_sub_y));
   }
 }
 
@@ -750,8 +776,10 @@
     max_chroma = max_luma;
   }
   const int scaling_shift = params_.chroma_scaling;
-  for (int y = 0; y < ((height_ + subsampling_y_) >> subsampling_y_); ++y) {
-    for (int x = 0; x < ((width_ + subsampling_x_) >> subsampling_x_); ++x) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       const int luma_x = x << subsampling_x_;
       const int luma_y = y << subsampling_y_;
       const int luma_next_x = std::min(luma_x + 1, width_ - 1);
@@ -806,27 +834,30 @@
       } else {
         out_v[y * dest_stride_v + x] = in_v[y * source_stride_v + x];
       }
-    }
-  }
+    } while (++x < ((width_ + subsampling_x_) >> subsampling_x_));
+  } while (++y < ((height_ + subsampling_y_) >> subsampling_y_));
   if (params_.num_y_points > 0) {
-    for (int y = 0; y < height_; ++y) {
-      for (int x = 0; x < width_; ++x) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
         const int orig = in_y[y * source_stride_y + x];
         int noise = noise_image_[kPlaneY][y][x];
         noise = RightShiftWithRounding(
             ScaleLut<bitdepth>(scaling_lut_y_, orig) * noise, scaling_shift);
         out_y[y * dest_stride_y + x] = Clip3(orig + noise, min_value, max_luma);
-      }
-    }
+      } while (++x < width_);
+    } while (++y < height_);
   } else if (in_y != out_y) {  // If in_y and out_y point to the same buffer,
                                // then do nothing.
     const Pixel* in_y_row = in_y;
     Pixel* out_y_row = out_y;
-    for (int y = 0; y < height_; ++y) {
+    int y = 0;
+    do {
       memcpy(out_y_row, in_y_row, width_ * sizeof(*out_y_row));
       in_y_row += source_stride_y;
       out_y_row += dest_stride_y;
-    }
+    } while (++y < height_);
   }
 }
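
The do/while rewrite keeps the auto-regressive coefficient ordering intact:
|pos| walks the full rows above the current sample and then the samples to its
left, stopping at the center tap, so the coefficient count is
lag * (2 * lag + 1) + lag = 2 * lag * (lag + 1). That is why the new asserts on
auto_regression_coeff_lag <= 3 bound the filter at 24 luma taps. A standalone
compile-time check:

    // Positions visited by the delta_row/delta_column do/while loops above.
    constexpr int NumAutoRegressionCoeffs(int lag) {
      return lag * (2 * lag + 1) + lag;  // == 2 * lag * (lag + 1)
    }
    static_assert(NumAutoRegressionCoeffs(1) == 4, "");
    static_assert(NumAutoRegressionCoeffs(2) == 12, "");
    static_assert(NumAutoRegressionCoeffs(3) == 24, "");
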
 
diff --git a/libgav1/src/dsp/intrapred.cc b/libgav1/src/dsp/intrapred.cc
index 22bde06..b602fc0 100644
--- a/libgav1/src/dsp/intrapred.cc
+++ b/libgav1/src/dsp/intrapred.cc
@@ -374,10 +374,12 @@
   stride /= sizeof(Pixel);
   int row0 = 0, row2 = 2;
   int ystep = 1;
-  for (int y = 0; y < height; y += 2) {
+  int y = 0;
+  do {
     buffer[1][0] = left[y];
     buffer[row2][0] = left[y + 1];
-    for (int x = 1; x < width; x += 4) {
+    int x = 1;
+    do {
       const Pixel p0 = buffer[row0][x - 1];  // top-left
       const Pixel p1 = buffer[row0][x + 0];  // top 0
       const Pixel p2 = buffer[row0][x + 1];  // top 1
@@ -398,7 +400,8 @@
         buffer[1 + yoffset][x + xoffset] = static_cast<Pixel>(
             Clip3(RightShiftWithRounding(value, 4), 0, kMaxPixel));
       }
-    }
+      x += 4;
+    } while (x < width);
     memcpy(dst, &buffer[1][1], width * sizeof(dst[0]));
     dst += stride;
     memcpy(dst, &buffer[row2][1], width * sizeof(dst[0]));
@@ -408,7 +411,8 @@
     row0 ^= 2;
     row2 ^= 2;
     ystep = -ystep;
-  }
+    y += 2;
+  } while (y < height);
 }
 
 //------------------------------------------------------------------------------
@@ -444,6 +448,8 @@
 void CflSubsampler_C(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
                      const int max_luma_width, const int max_luma_height,
                      const void* const source, ptrdiff_t stride) {
+  assert(max_luma_width >= 4);
+  assert(max_luma_height >= 4);
   const auto* src = static_cast<const Pixel*>(source);
   stride /= sizeof(Pixel);
   int sum = 0;
@@ -506,8 +512,9 @@
   const int max_base_x = ((width + height) - 1) << upsample_shift;
   const int scale_bits = 6 - upsample_shift;
   const int base_step = 1 << upsample_shift;
-  for (int y = 0, top_x = xstep; y < height;
-       ++y, dst += stride, top_x += xstep) {
+  int top_x = xstep;
+  int y = 0;
+  do {
     int top_base_x = top_x >> scale_bits;
 
     if (top_base_x >= max_base_x) {
@@ -519,7 +526,8 @@
     }
 
     const int shift = ((top_x << upsample_shift) & 0x3F) >> 1;
-    for (int x = 0; x < width; ++x, top_base_x += base_step) {
+    int x = 0;
+    do {
       if (top_base_x >= max_base_x) {
         Memset(dst + x, top[max_base_x], width - x);
         break;
@@ -528,8 +536,12 @@
       const int val =
           top[top_base_x] * (32 - shift) + top[top_base_x + 1] * shift;
       dst[x] = RightShiftWithRounding(val, 5);
-    }
-  }
+      top_base_x += base_step;
+    } while (++x < width);
+
+    dst += stride;
+    top_x += xstep;
+  } while (++y < height);
 }
 
 template <typename Pixel>
@@ -554,11 +566,13 @@
   const int scale_bits_y = 6 - upsample_left_shift;
   const int min_base_x = -(1 << upsample_top_shift);
   const int base_step_x = 1 << upsample_top_shift;
-  for (int y = 0, top_x = -xstep; y < height;
-       ++y, top_x -= xstep, dst += stride) {
-    for (int x = 0, top_base_x = top_x >> scale_bits_x,
-             left_y = (y << 6) - ystep;
-         x < width; ++x, top_base_x += base_step_x, left_y -= ystep) {
+  int y = 0;
+  int top_x = -xstep;
+  do {
+    int top_base_x = top_x >> scale_bits_x;
+    int left_y = (y << 6) - ystep;
+    int x = 0;
+    do {
       int val;
       if (top_base_x >= min_base_x) {
         const int shift = ((top_x * (1 << upsample_top_shift)) & 0x3F) >> 1;
@@ -571,8 +585,13 @@
         val = left[left_base_y] * (32 - shift) + left[left_base_y + 1] * shift;
       }
       dst[x] = RightShiftWithRounding(val, 5);
-    }
-  }
+      top_base_x += base_step_x;
+      left_y -= ystep;
+    } while (++x < width);
+
+    top_x -= xstep;
+    dst += stride;
+  } while (++y < height);
 }
 
 template <typename Pixel>
@@ -594,17 +613,24 @@
          ((ystep * width) >> scale_bits) +
              base_step * (height - 1));  // left_base_y
 
-  for (int x = 0, left_y = ystep; x < width; ++x, left_y += ystep) {
+  int left_y = ystep;
+  int x = 0;
+  do {
     auto* dst = static_cast<Pixel*>(dest);
 
-    for (int y = 0, left_base_y = left_y >> scale_bits; y < height;
-         ++y, left_base_y += base_step, dst += stride) {
+    int left_base_y = left_y >> scale_bits;
+    int y = 0;
+    do {
       const int shift = ((left_y << upsample_shift) & 0x3F) >> 1;
       const int val =
           left[left_base_y] * (32 - shift) + left[left_base_y + 1] * shift;
       dst[x] = RightShiftWithRounding(val, 5);
-    }
-  }
+      dst += stride;
+      left_base_y += base_step;
+    } while (++y < height);
+
+    left_y += ystep;
+  } while (++x < width);
 }
 
 //------------------------------------------------------------------------------
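
The directional predictors above walk a 26.6 fixed-point position (xstep or
ystep per row/column) along the reference array and blend the two neighboring
reference pixels with 5-bit weights. A scalar sketch of one sample
(standalone; upsample_shift omitted for clarity):

    #include <cstdint>

    // One interpolated sample from the reference array |ref| at 26.6
    // fixed-point position |pos|.
    inline uint8_t DirectionalSample(const uint8_t* ref, int pos) {
      const int base = pos >> 6;            // integer part
      const int shift = (pos & 0x3F) >> 1;  // 5-bit blend weight, in [0, 31]
      const int val = ref[base] * (32 - shift) + ref[base + 1] * shift;
      return static_cast<uint8_t>((val + 16) >> 5);  // RightShiftWithRounding(val, 5)
    }
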
diff --git a/libgav1/src/dsp/inverse_transform.cc b/libgav1/src/dsp/inverse_transform.cc
index bd9c324..1ce5e84 100644
--- a/libgav1/src/dsp/inverse_transform.cc
+++ b/libgav1/src/dsp/inverse_transform.cc
@@ -8,6 +8,7 @@
 #include "src/dsp/dsp.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
@@ -36,8 +37,9 @@
 }
 
 template <typename Residual>
-void ButterflyRotation_C(Residual* const dst, int a, int b, int angle,
-                         bool flip, int8_t range) {
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_C(Residual* const dst, int a,
+                                               int b, int angle, bool flip,
+                                               int8_t range) {
   // Note that we multiply in 32 bits and then add/subtract the products in 64
   // bits. The 32-bit multiplications do not overflow. Please see the comment
   // and assert() in Cos128().
diff --git a/libgav1/src/dsp/loop_restoration.cc b/libgav1/src/dsp/loop_restoration.cc
index 0a07c0e..76ef62a 100644
--- a/libgav1/src/dsp/loop_restoration.cc
+++ b/libgav1/src/dsp/loop_restoration.cc
@@ -22,6 +22,39 @@
 // Precision bits of generated values higher than source before projection.
 constexpr int kSgrProjRestoreBits = 4;
 
+// Section 7.17.3.
+// a2: range [1, 256].
+// if (z >= 255)
+//   a2 = 256;
+// else if (z == 0)
+//   a2 = 1;
+// else
+//   a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
+constexpr int kXByXPlus1[256] = {
+    1,   128, 171, 192, 205, 213, 219, 224, 228, 230, 233, 235, 236, 238, 239,
+    240, 241, 242, 243, 243, 244, 244, 245, 245, 246, 246, 247, 247, 247, 247,
+    248, 248, 248, 248, 249, 249, 249, 249, 249, 250, 250, 250, 250, 250, 250,
+    250, 251, 251, 251, 251, 251, 251, 251, 251, 251, 251, 252, 252, 252, 252,
+    252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 253, 253,
+    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253,
+    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 254, 254, 254,
+    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
+    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
+    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
+    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
+    254, 254, 254, 254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    256};
+
+constexpr int kOneByX[25] = {
+    4096, 2048, 1365, 1024, 819, 683, 585, 512, 455, 410, 372, 341, 315,
+    293,  273,  256,  241,  228, 216, 205, 195, 186, 178, 171, 164,
+};
+
 template <int bitdepth, typename Pixel>
 struct LoopRestorationFuncs_C {
   LoopRestorationFuncs_C() = delete;
@@ -210,16 +243,8 @@
       // (this holds even after accounting for the rounding in s)
       const uint32_t z = RightShiftWithRounding(p * s, kSgrProjScaleBits);
       // a2: range [1, 256].
-      uint32_t a2;
-      if (z >= 255) {
-        a2 = 256;
-      } else if (z == 0) {
-        a2 = 1;
-      } else {
-        a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
-      }
-      const uint32_t one_over_n =
-          ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+      uint32_t a2 = kXByXPlus1[std::min(z, 255u)];
+      const uint32_t one_over_n = kOneByX[n - 1];
       // (kSgrProjSgrBits - a2) < 2^8, b < 2^(bitdepth) * n,
       // one_over_n = round(2^12 / n)
       // => the product here is < 2^(20 + bitdepth) <= 2^32,
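
The two tables replace per-pixel divisions with lookups; they can be
regenerated from the expressions quoted in the comment. A standalone sketch,
assuming kSgrProjSgrBits == 8 and kSgrProjReciprocalBits == 12 (consistent with
the 128 and 4096 entries above):

    #include <cstdint>

    constexpr int kSgrProjSgrBits = 8;          // assumed; gives table[1] == 128
    constexpr int kSgrProjReciprocalBits = 12;  // assumed; gives round(4096 / n)

    // kXByXPlus1[z] for z in [0, 255].
    uint32_t XByXPlus1(uint32_t z) {
      if (z >= 255) return 256;
      if (z == 0) return 1;
      return ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
    }

    // kOneByX[n - 1] for n in [1, 25].
    uint32_t OneByX(uint32_t n) {
      return ((1u << kSgrProjReciprocalBits) + (n >> 1)) / n;
    }
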
diff --git a/libgav1/src/dsp/obmc.inc b/libgav1/src/dsp/obmc.inc
index 6da0001..e510226 100644
--- a/libgav1/src/dsp/obmc.inc
+++ b/libgav1/src/dsp/obmc.inc
@@ -14,5 +14,5 @@
     // Obmc Mask 16
     34, 37, 40, 43, 46, 49, 52, 54, 56, 58, 60, 61, 64, 64, 64, 64,
     // Obmc Mask 32
-    33, 35, 36, 38, 40, 41, 43, 44, 45, 47, 48, 50, 51, 52, 53, 55, 56, 57,
-    58, 59, 60, 60, 61, 62, 64, 64, 64, 64, 64, 64, 64, 64};
+    33, 35, 36, 38, 40, 41, 43, 44, 45, 47, 48, 50, 51, 52, 53, 55, 56, 57, 58,
+    59, 60, 60, 61, 62, 64, 64, 64, 64, 64, 64, 64, 64};
diff --git a/libgav1/src/dsp/warp.cc b/libgav1/src/dsp/warp.cc
index 60aab23..1dd904f 100644
--- a/libgav1/src/dsp/warp.cc
+++ b/libgav1/src/dsp/warp.cc
@@ -10,6 +10,7 @@
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
+#include "src/utils/memory.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -22,7 +23,7 @@
 void Warp_C(const void* const source, ptrdiff_t source_stride,
             const int source_width, const int source_height,
             const int* const warp_params, const int subsampling_x,
-            const int subsampling_y, const uint8_t inter_round_bits_vertical,
+            const int subsampling_y, const int inter_round_bits_vertical,
             const int block_start_x, const int block_start_y,
             const int block_width, const int block_height, const int16_t alpha,
             const int16_t beta, const int16_t gamma, const int16_t delta,
@@ -69,6 +70,31 @@
         // filtering.
         const int row = Clip3(iy4 + y, 0, source_height - 1);
         const Pixel* const src_row = src + row * source_stride;
+        // Check for two simple special cases.
+        if (ix4 - 7 >= source_width - 1) {
+          // Every sample is equal to src_row[source_width - 1]. Since the sum
+          // of the warped filter coefficients is 128 (= 2^7), the filtering is
+          // equivalent to multiplying src_row[source_width - 1] by 128.
+          const int s =
+              (horizontal_offset >> kInterRoundBitsHorizontal) +
+              (src_row[source_width - 1] << (7 - kInterRoundBitsHorizontal));
+          Memset(intermediate_result[y + 7], s, 8);
+          sx4 += beta;
+          continue;
+        }
+        if (ix4 + 7 <= 0) {
+          // Every sample is equal to src_row[0]. Since the sum of the warped
+          // filter coefficients is 128 (= 2^7), the filtering is equivalent to
+          // multiplying src_row[0] by 128.
+          const int s = (horizontal_offset >> kInterRoundBitsHorizontal) +
+                        (src_row[0] << (7 - kInterRoundBitsHorizontal));
+          Memset(intermediate_result[y + 7], s, 8);
+          sx4 += beta;
+          continue;
+        }
+        // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+        // It follows that -6 <= ix4 <= source_width + 5. This inequality is
+        // used below.
         int sx = sx4 - MultiplyBy4(alpha);
         for (int x = -4; x < 4; ++x) {
           const int offset =
@@ -87,7 +113,15 @@
           // uint16_t.
           // For 10/12 bit, the range of sum is within 32 bits.
           for (int k = 0; k < 8; ++k) {
-            const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+            // We assume the source frame has left and right borders of at
+            // least 13 pixels that extend the frame boundary pixels.
+            //
+            // Since -4 <= x <= 3 and 0 <= k <= 7, using the inequality on ix4
+            // above, we have -13 <= ix4 + x + k - 3 <= source_width + 12, or
+            // -13 <= column <= (source_width - 1) + 13. Therefore we may
+            // over-read up to 13 pixels before the source row, or up to 13
+            // pixels after the source row.
+            const int column = ix4 + x + k - 3;
             sum += kWarpedFilters[offset][k] * src_row[column];
           }
           assert(sum >= 0 && sum < (horizontal_offset << 2));
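
The two special cases above rely on the warped filter taps summing to
128 = 2^7, so filtering a constant row degenerates to a shift. Since both
horizontal_offset and pixel << 7 are multiples of 2^kInterRoundBitsHorizontal,
the shift distributes exactly across the two terms, which is what the
precomputed |s| exploits. A worked check with illustrative taps (the real
kWarpedFilters values differ; kInterRoundBitsHorizontal == 3 is the assumed
8bpp value):

    #include <cassert>

    int main() {
      const int taps[8] = {-2, 10, 20, 36, 36, 20, 10, -2};  // taps sum to 128
      const int p = 117;  // constant source pixel
      const int horizontal_offset = 1 << 14;
      const int kInterRoundBitsHorizontal = 3;

      int sum = horizontal_offset;
      for (const int tap : taps) sum += tap * p;  // == horizontal_offset + 128 * p

      // Matches the precomputed |s| in the special cases above.
      assert(sum >> kInterRoundBitsHorizontal ==
             (horizontal_offset >> kInterRoundBitsHorizontal) +
                 (p << (7 - kInterRoundBitsHorizontal)));
      return 0;
    }
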
diff --git a/libgav1/src/dsp/x86/common_sse4.h b/libgav1/src/dsp/x86/common_sse4.h
index 305837a..5d77216 100644
--- a/libgav1/src/dsp/x86/common_sse4.h
+++ b/libgav1/src/dsp/x86/common_sse4.h
@@ -156,6 +156,18 @@
   return _mm_srai_epi32(v_tmp_d, bits);
 }
 
+//------------------------------------------------------------------------------
+// Masking utilities
+inline __m128i MaskHighNBytes(int n) {
+  const uint8_t lu_table[32] = {
+      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
+      0,   0,   0,   0,   0,   255, 255, 255, 255, 255, 255,
+      255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+  };
+
+  return LoadUnaligned16(lu_table + n);
+}
+
 }  // namespace dsp
 }  // namespace libgav1
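
MaskHighNBytes() reads 16 bytes starting at offset n into a 32-byte table of 16
zeros followed by 16 0xFF bytes, so byte i of the result is 0xFF exactly when
n + i >= 16. In scalar form (a model for n in [0, 16]; not libgav1 code):

    #include <cstdint>

    // Byte i of MaskHighNBytes(n) is 0xFF iff i >= 16 - n, i.e. the mask
    // selects the high n bytes of a 16-byte vector.
    void MaskHighNBytesModel(int n, uint8_t mask[16]) {
      for (int i = 0; i < 16; ++i) {
        mask[i] = (i >= 16 - n) ? 0xFF : 0;
      }
    }
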
 
diff --git a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
index 4d97fb0..26c2ba0 100644
--- a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
@@ -13,6 +13,7 @@
 #include "src/dsp/constants.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -69,44 +70,52 @@
   } while ((row += (1 << kCflLumaBufferStrideLog2_128i)) < row_end);
 }
 
-template <int block_height_log2>
+template <int block_height_log2, bool is_inside>
 void CflSubsampler444_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int /*max_luma_width*/, const int /*max_luma_height*/,
+    const int /*max_luma_width*/, const int max_luma_height,
     const void* const source, ptrdiff_t stride) {
+  static_assert(block_height_log2 <= 4, "");
   const int block_height = 1 << block_height_log2;
+  const int visible_height = max_luma_height;
   const auto* src = static_cast<const uint8_t*>(source);
   __m128i sum = _mm_setzero_si128();
   int16_t* luma_ptr = luma[0];
   const __m128i zero = _mm_setzero_si128();
-  for (int y = 0; y < block_height; y += 4) {
-    __m128i samples01 = Load4(src);
+  __m128i samples;
+  int y = 0;
+  do {
+    samples = Load4(src);
     src += stride;
     int src_bytes;
     memcpy(&src_bytes, src, 4);
-    samples01 = _mm_insert_epi32(samples01, src_bytes, 1);
+    samples = _mm_insert_epi32(samples, src_bytes, 1);
     src += stride;
-    samples01 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples01), 3);
-    StoreLo8(luma_ptr, samples01);
+    samples = _mm_slli_epi16(_mm_cvtepu8_epi16(samples), 3);
+    StoreLo8(luma_ptr, samples);
     luma_ptr += kCflLumaBufferStride;
-    StoreHi8(luma_ptr, samples01);
+    StoreHi8(luma_ptr, samples);
     luma_ptr += kCflLumaBufferStride;
 
-    __m128i samples23 = Load4(src);
-    src += stride;
-    memcpy(&src_bytes, src, 4);
-    samples23 = _mm_insert_epi32(samples23, src_bytes, 1);
-    src += stride;
-    samples23 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples23), 3);
-    StoreLo8(luma_ptr, samples23);
-    luma_ptr += kCflLumaBufferStride;
-    StoreHi8(luma_ptr, samples23);
-    luma_ptr += kCflLumaBufferStride;
+    // The maximum value here is 2**bd * H * 2**shift. Since the maximum H for
+    // 4xH is 16 = 2**4, we have 2**(8 + 4 + 3) = 2**15, which fits in 16 bits.
+    sum = _mm_add_epi16(sum, samples);
+    y += 2;
+  } while (y < visible_height);
 
-    const __m128i sample_sum = _mm_add_epi16(samples01, samples23);
-    sum = _mm_add_epi32(sum, _mm_cvtepu16_epi32(sample_sum));
-    sum = _mm_add_epi32(sum, _mm_unpackhi_epi16(sample_sum, zero));
+  if (!is_inside) {
+    int y = visible_height;
+    do {
+      StoreHi8(luma_ptr, samples);
+      luma_ptr += kCflLumaBufferStride;
+      sum = _mm_add_epi16(sum, samples);
+      ++y;
+    } while (y < block_height);
   }
+
+  __m128i sum_tmp = _mm_unpackhi_epi16(sum, zero);
+  sum = _mm_cvtepu16_epi32(sum);
+  sum = _mm_add_epi32(sum, sum_tmp);
   sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
   sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4));
 
@@ -121,33 +130,93 @@
 }
 
 template <int block_height_log2>
+void CflSubsampler444_4xH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  static_assert(block_height_log2 <= 4, "");
+  assert(max_luma_width >= 4);
+  assert(max_luma_height >= 4);
+  const int block_height = 1 << block_height_log2;
+  const int block_width = 4;
+
+  if (block_height <= max_luma_height && block_width <= max_luma_width) {
+    CflSubsampler444_4xH_SSE4_1<block_height_log2, true>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  } else {
+    CflSubsampler444_4xH_SSE4_1<block_height_log2, false>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  }
+}
+
+template <int block_height_log2, bool inside>
 void CflSubsampler444_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int /*max_luma_width*/, const int /*max_luma_height*/,
+    const int max_luma_width, const int max_luma_height,
     const void* const source, ptrdiff_t stride) {
-  const int block_height = 1 << block_height_log2;
+  static_assert(block_height_log2 <= 5, "");
+  const int block_height = 1 << block_height_log2, block_width = 8;
+  const int visible_height = max_luma_height;
+  const int invisible_width = inside ? 0 : block_width - max_luma_width;
+  const int visible_width = max_luma_width;
+  const __m128i blend_mask =
+      inside ? _mm_setzero_si128() : MaskHighNBytes(8 + invisible_width);
   const __m128i dup16 = _mm_set1_epi32(0x01000100);
   const auto* src = static_cast<const uint8_t*>(source);
-  __m128i sum = _mm_setzero_si128();
   int16_t* luma_ptr = luma[0];
   const __m128i zero = _mm_setzero_si128();
-  for (int y = 0; y < block_height; y += 2) {
+  // Since the maximum height is 32, splitting the rows by parity means each
+  // accumulator sums at most 16 rows. Just as in the 4xH case, the sums fit in
+  // 16 bits without widening to 32 bits.
+  __m128i sum_even = _mm_setzero_si128(), sum_odd = _mm_setzero_si128();
+  __m128i sum;
+  __m128i samples1;
+
+  int y = 0;
+  do {
     __m128i samples0 = LoadLo8(src);
+    if (!inside) {
+      const __m128i border0 = _mm_set1_epi8(src[visible_width - 1]);
+      samples0 = _mm_blendv_epi8(samples0, border0, blend_mask);
+    }
     src += stride;
     samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples0), 3);
     StoreUnaligned16(luma_ptr, samples0);
     luma_ptr += kCflLumaBufferStride;
 
-    __m128i samples1 = LoadLo8(src);
+    sum_even = _mm_add_epi16(sum_even, samples0);
+
+    samples1 = LoadLo8(src);
+    if (!inside) {
+      const __m128i border1 = _mm_set1_epi8(src[visible_width - 1]);
+      samples1 = _mm_blendv_epi8(samples1, border1, blend_mask);
+    }
     src += stride;
     samples1 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples1), 3);
     StoreUnaligned16(luma_ptr, samples1);
     luma_ptr += kCflLumaBufferStride;
 
-    const __m128i sample_sum = _mm_add_epi16(samples0, samples1);
-    sum = _mm_add_epi32(sum, _mm_cvtepu16_epi32(sample_sum));
-    sum = _mm_add_epi32(sum, _mm_unpackhi_epi16(sample_sum, zero));
+    sum_odd = _mm_add_epi16(sum_odd, samples1);
+    y += 2;
+  } while (y < visible_height);
+
+  if (!inside) {
+    for (int y = visible_height; y < block_height; y += 2) {
+      sum_even = _mm_add_epi16(sum_even, samples1);
+      StoreUnaligned16(luma_ptr, samples1);
+      luma_ptr += kCflLumaBufferStride;
+
+      sum_odd = _mm_add_epi16(sum_odd, samples1);
+      StoreUnaligned16(luma_ptr, samples1);
+      luma_ptr += kCflLumaBufferStride;
+    }
   }
+
+  sum = _mm_add_epi32(_mm_unpackhi_epi16(sum_even, zero),
+                      _mm_cvtepu16_epi32(sum_even));
+  sum = _mm_add_epi32(sum, _mm_unpackhi_epi16(sum_odd, zero));
+  sum = _mm_add_epi32(sum, _mm_cvtepu16_epi32(sum_odd));
+
   sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
   sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4));
 
@@ -161,44 +230,137 @@
   }
 }
 
+template <int block_height_log2>
+void CflSubsampler444_8xH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  static_assert(block_height_log2 <= 5, "");
+  assert(max_luma_width >= 4);
+  assert(max_luma_height >= 4);
+  const int block_height = 1 << block_height_log2;
+  const int block_width = 8;
+
+  const int horz_inside = block_width <= max_luma_width;
+  const int vert_inside = block_height <= max_luma_height;
+  if (horz_inside && vert_inside) {
+    CflSubsampler444_8xH_SSE4_1<block_height_log2, true>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  } else {
+    CflSubsampler444_8xH_SSE4_1<block_height_log2, false>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  }
+}
+
 // This function will only work for block_width 16 and 32.
-template <int block_width_log2, int block_height_log2>
+template <int block_width_log2, int block_height_log2, bool inside>
 void CflSubsampler444_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int /*max_luma_width*/, const int /*max_luma_height*/,
+    const int max_luma_width, const int max_luma_height,
     const void* const source, ptrdiff_t stride) {
   static_assert(block_width_log2 == 4 || block_width_log2 == 5, "");
+  static_assert(block_height_log2 <= 5, "");
+  assert(max_luma_width >= 4);
+  assert(max_luma_height >= 4);
   const int block_height = 1 << block_height_log2;
   const int block_width = 1 << block_width_log2;
+
+  const int visible_height = max_luma_height;
+  const int visible_width_16 = inside ? 16 : std::min(16, max_luma_width);
+  const int invisible_width_16 = 16 - visible_width_16;
+  const __m128i blend_mask_16 = MaskHighNBytes(invisible_width_16);
+  const int visible_width_32 = inside ? 32 : max_luma_width;
+  const int invisible_width_32 = 32 - visible_width_32;
+  const __m128i blend_mask_32 =
+      MaskHighNBytes(std::min(16, invisible_width_32));
+
   const __m128i dup16 = _mm_set1_epi32(0x01000100);
   const __m128i zero = _mm_setzero_si128();
   const auto* src = static_cast<const uint8_t*>(source);
   int16_t* luma_ptr = luma[0];
   __m128i sum = _mm_setzero_si128();
-  for (int y = 0; y < block_height;
-       luma_ptr += kCflLumaBufferStride, src += stride, ++y) {
-    const __m128i samples01 = LoadUnaligned16(src);
-    const __m128i samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples01), 3);
-    const __m128i samples1 =
-        _mm_slli_epi16(_mm_unpackhi_epi8(samples01, zero), 3);
+
+  __m128i samples0, samples1;
+  __m128i samples2, samples3;
+  __m128i inner_sum_lo, inner_sum_hi;
+  int y = 0;
+  do {
+#if LIBGAV1_MSAN  // We may load uninitialized values here. Even though they
+                  // are then masked off by blendv, MSAN isn't smart enough to
+                  // understand that, so we switch to a C implementation here.
+    uint16_t c_arr[16];
+    for (int x = 0; x < 16; x++) {
+      const int x_index = std::min(x, visible_width_16 - 1);
+      c_arr[x] = src[x_index] << 3;
+    }
+    samples0 = LoadUnaligned16(c_arr);
+    samples1 = LoadUnaligned16(c_arr + 8);
+    static_cast<void>(blend_mask_16);
+#else
+    __m128i samples01 = LoadUnaligned16(src);
+
+    if (!inside) {
+      const __m128i border16 = _mm_set1_epi8(src[visible_width_16 - 1]);
+      samples01 = _mm_blendv_epi8(samples01, border16, blend_mask_16);
+    }
+    samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples01), 3);
+    samples1 = _mm_slli_epi16(_mm_unpackhi_epi8(samples01, zero), 3);
+#endif  // LIBGAV1_MSAN
+
     StoreUnaligned16(luma_ptr, samples0);
     StoreUnaligned16(luma_ptr + 8, samples1);
     __m128i inner_sum = _mm_add_epi16(samples0, samples1);
+
     if (block_width == 32) {
-      const __m128i samples23 = LoadUnaligned16(src + 16);
-      const __m128i samples2 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples23), 3);
-      const __m128i samples3 =
-          _mm_slli_epi16(_mm_unpackhi_epi8(samples23, zero), 3);
+#if LIBGAV1_MSAN  // We may load uninitialized values here. Even though they
+                  // are then masked off by blendv, MSAN isn't smart enough to
+                  // understand that, so we switch to a C implementation here.
+      uint16_t c_arr[16];
+      for (int x = 16; x < 32; x++) {
+        const int x_index = std::min(x, visible_width_32 - 1);
+        c_arr[x - 16] = src[x_index] << 3;
+      }
+      samples2 = LoadUnaligned16(c_arr);
+      samples3 = LoadUnaligned16(c_arr + 8);
+      static_cast<void>(blend_mask_32);
+#else
+      __m128i samples23 = LoadUnaligned16(src + 16);
+      if (!inside) {
+        const __m128i border32 = _mm_set1_epi8(src[visible_width_32 - 1]);
+        samples23 = _mm_blendv_epi8(samples23, border32, blend_mask_32);
+      }
+      samples2 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples23), 3);
+      samples3 = _mm_slli_epi16(_mm_unpackhi_epi8(samples23, zero), 3);
+#endif  // LIBGAV1_MSAN
+
       StoreUnaligned16(luma_ptr + 16, samples2);
       StoreUnaligned16(luma_ptr + 24, samples3);
       inner_sum = _mm_add_epi16(samples2, inner_sum);
       inner_sum = _mm_add_epi16(samples3, inner_sum);
     }
-    const __m128i inner_sum_lo = _mm_cvtepu16_epi32(inner_sum);
-    const __m128i inner_sum_hi = _mm_unpackhi_epi16(inner_sum, zero);
+
+    inner_sum_lo = _mm_cvtepu16_epi32(inner_sum);
+    inner_sum_hi = _mm_unpackhi_epi16(inner_sum, zero);
     sum = _mm_add_epi32(sum, inner_sum_lo);
     sum = _mm_add_epi32(sum, inner_sum_hi);
+    luma_ptr += kCflLumaBufferStride;
+    src += stride;
+  } while (++y < visible_height);
+
+  if (!inside) {
+    for (int y = visible_height; y < block_height;
+         luma_ptr += kCflLumaBufferStride, ++y) {
+      sum = _mm_add_epi32(sum, inner_sum_lo);
+      StoreUnaligned16(luma_ptr, samples0);
+      sum = _mm_add_epi32(sum, inner_sum_hi);
+      StoreUnaligned16(luma_ptr + 8, samples1);
+      if (block_width == 32) {
+        StoreUnaligned16(luma_ptr + 16, samples2);
+        StoreUnaligned16(luma_ptr + 24, samples3);
+      }
+    }
   }
+
   sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
   sum = _mm_add_epi32(sum, _mm_srli_si128(sum, 4));
 
@@ -214,6 +376,29 @@
   }
 }
 
+template <int block_width_log2, int block_height_log2>
+void CflSubsampler444_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  static_assert(block_width_log2 == 4 || block_width_log2 == 5, "");
+  static_assert(block_height_log2 <= 5, "");
+  assert(max_luma_width >= 4);
+  assert(max_luma_height >= 4);
+
+  const int block_height = 1 << block_height_log2;
+  const int block_width = 1 << block_width_log2;
+  const int horz_inside = block_width <= max_luma_width;
+  const int vert_inside = block_height <= max_luma_height;
+  if (horz_inside && vert_inside) {
+    CflSubsampler444_SSE4_1<block_width_log2, block_height_log2, true>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  } else {
+    CflSubsampler444_SSE4_1<block_width_log2, block_height_log2, false>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  }
+}
+
 // Takes in two sums of input row pairs, and completes the computation for two
 // output rows.
 inline __m128i StoreLumaResults4_420(const __m128i vertical_sum0,
@@ -640,8 +825,7 @@
   dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] =
       CflSubsampler420_WxH_SSE4_1<5, 5>;
 #endif
-  // TODO(b/137035169): enable these once test vector mismatches are fixed.
-#if 0
+
 #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflSubsampler444)
   dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] =
       CflSubsampler444_4xH_SSE4_1<2>;
@@ -698,7 +882,6 @@
   dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] =
       CflSubsampler444_SSE4_1<5, 5>;
 #endif
-#endif
 #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflIntraPredictor)
   dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor_SSE4_1<4, 4>;
 #endif
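
Re-enabling these 444 subsamplers hinges on the new inside/outside split above:
a thin wrapper picks the |inside| template specialization at runtime, and
out-of-bounds columns and rows are filled with replicated edge samples so the
sums stay correct without separate code paths. The 16-bit accumulation bounds
claimed in the comments can be checked at compile time (standalone sketch):

    // 4xH: at most 16 rows of 8-bit samples, each scaled by 2**3.
    static_assert(16 * 255 * 8 < (1 << 15), "4xH sums fit in 16-bit lanes");
    // 8xH: height up to 32, but rows are split by parity across two
    // accumulators, so each again sums at most 16 rows.
    static_assert((32 / 2) * 255 * 8 < (1 << 15), "8xH parity sums fit too");
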
diff --git a/libgav1/src/dsp/x86/intrapred_sse4.h b/libgav1/src/dsp/x86/intrapred_sse4.h
index ec5dfb0..c44c533 100644
--- a/libgav1/src/dsp/x86/intrapred_sse4.h
+++ b/libgav1/src/dsp/x86/intrapred_sse4.h
@@ -178,8 +178,6 @@
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
 #endif
 
-// TODO(b/137035169): enable these once test vector mismatches are fixed.
-#if 0
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_DSP_SSE4_1
 #endif
@@ -235,7 +233,6 @@
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_DSP_SSE4_1
 #endif
-#endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_DSP_SSE4_1
diff --git a/libgav1/src/post_filter.cc b/libgav1/src/post_filter.cc
index 107a04c..761f1b6 100644
--- a/libgav1/src/post_filter.cc
+++ b/libgav1/src/post_filter.cc
@@ -11,6 +11,7 @@
 #include "src/dsp/constants.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/blocking_counter.h"
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 #include "src/utils/memory.h"
 #include "src/utils/types.h"
@@ -113,6 +114,8 @@
     const int plane_width =
         RightShiftWithRounding(upscaled_width_, subsampling_x);
     const int plane_height = RightShiftWithRounding(height_, subsampling_y);
+    assert(source_buffer_->left_border(plane) >= kMinLeftBorderPixels &&
+           source_buffer_->right_border(plane) >= kMinRightBorderPixels);
     ExtendFrameBoundary(
         source_buffer_->data(plane), plane_width, plane_height,
         source_buffer_->stride(plane), source_buffer_->left_border(plane),
diff --git a/libgav1/src/prediction_mask.cc b/libgav1/src/prediction_mask.cc
index 8a2a63f..2ea8e0c 100644
--- a/libgav1/src/prediction_mask.cc
+++ b/libgav1/src/prediction_mask.cc
@@ -136,33 +136,6 @@
   return kWedgeCodebook[BlockShape(block_size)][index][2];
 }
 
-void DifferenceWeightMask(const uint16_t* prediction_1,
-                          const ptrdiff_t stride_1,
-                          const uint16_t* prediction_2,
-                          const ptrdiff_t stride_2, const int bitdepth,
-                          const bool mask_is_inverse, const int width,
-                          const int height, uint8_t* mask,
-                          const ptrdiff_t mask_stride) {
-#if LIBGAV1_MAX_BITDEPTH == 12
-  const int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
-#else
-  constexpr int inter_post_round_bits = 4;
-#endif
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
-      const int rounding_bits = bitdepth - 8 + inter_post_round_bits;
-      const int difference = RightShiftWithRounding(
-          std::abs(prediction_1[x] - prediction_2[x]), rounding_bits);
-      const auto mask_value =
-          static_cast<uint8_t>(std::min(DivideBy16(difference) + 38, 64));
-      mask[x] = mask_is_inverse ? 64 - mask_value : mask_value;
-    }
-    prediction_1 += stride_1;
-    prediction_2 += stride_2;
-    mask += mask_stride;
-  }
-}
-
 }  // namespace
 
 void GenerateWedgeMask(uint8_t* const wedge_master_mask_data,
@@ -284,14 +257,29 @@
   }
 }
 
-void GenerateWeightMask(const uint16_t* const prediction_1,
-                        const ptrdiff_t stride_1,
-                        const uint16_t* const prediction_2,
-                        const ptrdiff_t stride_2, const bool mask_is_inverse,
-                        const int width, const int height, const int bitdepth,
-                        uint8_t* const mask, const ptrdiff_t mask_stride) {
-  DifferenceWeightMask(prediction_1, stride_1, prediction_2, stride_2, bitdepth,
-                       mask_is_inverse, width, height, mask, mask_stride);
+void GenerateWeightMask(const uint16_t* prediction_1, const ptrdiff_t stride_1,
+                        const uint16_t* prediction_2, const ptrdiff_t stride_2,
+                        const bool mask_is_inverse, const int width,
+                        const int height, const int bitdepth, uint8_t* mask,
+                        const ptrdiff_t mask_stride) {
+#if LIBGAV1_MAX_BITDEPTH == 12
+  const int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+#else
+  constexpr int inter_post_round_bits = 4;
+#endif
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) {
+      const int rounding_bits = bitdepth - 8 + inter_post_round_bits;
+      const int difference = RightShiftWithRounding(
+          std::abs(prediction_1[x] - prediction_2[x]), rounding_bits);
+      const auto mask_value =
+          static_cast<uint8_t>(std::min(DivideBy16(difference) + 38, 64));
+      mask[x] = mask_is_inverse ? 64 - mask_value : mask_value;
+    }
+    prediction_1 += stride_1;
+    prediction_2 += stride_2;
+    mask += mask_stride;
+  }
 }
 
 void GenerateInterIntraMask(const int mode, const int width, const int height,
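
GenerateWeightMask() now computes the difference-weighted mask inline (the
DifferenceWeightMask() helper above was folded in unchanged). Per pixel it
reduces to the following scalar model for 8bpp, where
rounding_bits == bitdepth - 8 + inter_post_round_bits == 4:

    #include <algorithm>
    #include <cstdlib>

    // Scalar model of one mask value for bitdepth == 8.
    int WeightMaskValue(int p1, int p2) {
      const int difference = (std::abs(p1 - p2) + 8) >> 4;  // RightShiftWithRounding(., 4)
      return std::min(difference / 16 + 38, 64);            // DivideBy16(.) + bias, capped
    }
    // Example: WeightMaskValue(1000, 200) -> (800 + 8) >> 4 == 50,
    // then std::min(50 / 16 + 38, 64) == 41.
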
diff --git a/libgav1/src/threading_strategy.cc b/libgav1/src/threading_strategy.cc
index 440cad0..9c2d54d 100644
--- a/libgav1/src/threading_strategy.cc
+++ b/libgav1/src/threading_strategy.cc
@@ -3,15 +3,10 @@
 #include <algorithm>
 #include <cassert>
 
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
-namespace {
-
-// Maximum number of threads that the library will ever create.
-constexpr int kMaxThreads = 32;
-
-}  // namespace
 
 bool ThreadingStrategy::Reset(const ObuFrameHeader& frame_header,
                               int thread_count) {
@@ -25,7 +20,7 @@
 
   // We do work in the current thread, so it is sufficient to create
   // |thread_count|-1 threads in the threadpool.
-  thread_count = std::min(thread_count - 1, kMaxThreads);
+  thread_count = std::min(thread_count - 1, static_cast<int>(kMaxThreads));
 
   if (thread_pool_ == nullptr || thread_pool_->num_threads() != thread_count) {
     thread_pool_ = ThreadPool::Create("libgav1", thread_count);
diff --git a/libgav1/src/tile.h b/libgav1/src/tile.h
index 3895f5e..fd7b7c0 100644
--- a/libgav1/src/tile.h
+++ b/libgav1/src/tile.h
@@ -12,6 +12,7 @@
 #include <vector>
 
 #include "src/buffer_pool.h"
+#include "src/decoder_scratch_buffer.h"
 #include "src/dsp/common.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
@@ -66,6 +67,7 @@
        BlockParametersHolder* block_parameters, Array2D<int16_t>* cdef_index,
        Array2D<TransformSize>* inter_transform_sizes, const dsp::Dsp* dsp,
        ThreadPool* thread_pool, ResidualBufferPool* residual_buffer_pool,
+       DecoderScratchBufferPool* decoder_scratch_buffer_pool,
        BlockingCounterWithStatus* pending_tiles);
 
   // Move only.
@@ -111,44 +113,6 @@
     int depth;
   };
 
-  static constexpr int kBlockDecodedStride = 34;
-  // Buffer to facilitate decoding a superblock. When |split_parse_and_decode_|
-  // is true, each superblock that is being decoded will get its own instance of
-  // this buffer.
-  struct SuperBlockBuffer {
-    // The size of mask is 128x128.
-    AlignedUniquePtr<uint8_t> prediction_mask;
-    // Buffer used for inter prediction process. The buffers have an alignment
-    // of 8 bytes when allocated.
-    AlignedUniquePtr<uint16_t> prediction_buffer[2];
-    size_t prediction_buffer_size[2] = {};
-    // Equivalent to BlockDecoded array in the spec. This stores the decoded
-    // state of every 4x4 block in a superblock. It has 1 row/column border on
-    // all 4 sides (hence the 34x34 dimension instead of 32x32). Note that the
-    // spec uses "-1" as an index to access the left and top borders. In the
-    // code, we treat the index (1, 1) as equivalent to the spec's (0, 0). So
-    // all accesses into this array will be offset by +1 when compared with the
-    // spec.
-    bool block_decoded[kMaxPlanes][kBlockDecodedStride][kBlockDecodedStride];
-    // Buffer used for storing subsampled luma samples needed for CFL
-    // prediction. This buffer is used to avoid repetition of the subsampling
-    // for the V plane when it is already done for the U plane.
-    int16_t cfl_luma_buffer[kCflLumaBufferStride][kCflLumaBufferStride];
-    bool cfl_luma_buffer_valid;
-    // The |residual| pointer is used to traverse the |residual_buffer_|. It is
-    // used in two different ways.
-    // If |split_parse_and_decode_| is true:
-    //    |residual| points to the beginning of the |residual_buffer_| when the
-    //    "parse" and "decode" steps begin. It is then moved forward tx_size in
-    //    each iteration of the "parse" and the "decode" steps.
-    // If |split_parse_and_decode_| is false:
-    //    |residual| is reset to the beginning of the |residual_buffer_| for
-    //    every transform block.
-    uint8_t* residual;
-    // This queue is only used when |split_parse_and_decode_| is true.
-    TransformParameterQueue* transform_parameters;
-  };
-
   // Enum to track the processing state of a superblock.
   enum SuperBlockState : uint8_t {
     kSuperBlockStateNone,       // Not yet parsed or decoded.
@@ -169,6 +133,20 @@
     std::condition_variable pending_jobs_zero_condvar;
   };
 
+  // The residual pointer is used to traverse the |residual_buffer_|. It is
+  // used in two different ways.
+  // If |split_parse_and_decode_| is true:
+  //    The pointer points to the beginning of the |residual_buffer_| when the
+  //    "parse" and "decode" steps begin. It is then moved forward tx_size in
+  //    each iteration of the "parse" and the "decode" steps. In this case, the
+  //    ResidualPtr variable passed into various functions starting from
+  //    ProcessSuperBlock is used as an in/out parameter to keep track of the
+  //    residual pointer.
+  // If |split_parse_and_decode_| is false:
+  //    The pointer is reset to the beginning of the |residual_buffer_| for
+  //    every transform block.
+  using ResidualPtr = uint8_t*;
+
   // Performs member initializations that may fail. Called by Decode().
   LIBGAV1_MUST_USE_RESULT bool Init();
 
@@ -200,23 +178,28 @@
   // the blocks in the right order.
   bool ProcessPartition(
       int row4x4_start, int column4x4_start, ParameterTree* root,
-      SuperBlockBuffer* sb_buffer);  // Iterative implementation of 5.11.4.
+      DecoderScratchBuffer* scratch_buffer,
+      ResidualPtr* residual);  // Iterative implementation of 5.11.4.
   bool ProcessBlock(int row4x4, int column4x4, BlockSize block_size,
-                    ParameterTree* tree,
-                    SuperBlockBuffer* sb_buffer);  // 5.11.5.
-  void ResetCdef(int row4x4, int column4x4);       // 5.11.55.
+                    ParameterTree* tree, DecoderScratchBuffer* scratch_buffer,
+                    ResidualPtr* residual);   // 5.11.5.
+  void ResetCdef(int row4x4, int column4x4);  // 5.11.55.
 
   // This function is used to decode a superblock when the parsing has already
   // been done for that superblock.
-  bool DecodeSuperBlock(ParameterTree* tree, SuperBlockBuffer* sb_buffer);
+  bool DecodeSuperBlock(ParameterTree* tree,
+                        DecoderScratchBuffer* scratch_buffer,
+                        ResidualPtr* residual);
   // Helper function used by DecodeSuperBlock(). Note that the decode_block()
   // function in the spec is equivalent to ProcessBlock() in the code.
-  bool DecodeBlock(ParameterTree* tree, SuperBlockBuffer* sb_buffer);
+  bool DecodeBlock(ParameterTree* tree, DecoderScratchBuffer* scratch_buffer,
+                   ResidualPtr* residual);
 
-  void ClearBlockDecoded(SuperBlockBuffer* sb_buffer, int row4x4,
+  void ClearBlockDecoded(DecoderScratchBuffer* scratch_buffer, int row4x4,
                          int column4x4);  // 5.11.3.
   bool ProcessSuperBlock(int row4x4, int column4x4, int block_width4x4,
-                         SuperBlockBuffer* sb_buffer, ProcessingMode mode);
+                         DecoderScratchBuffer* scratch_buffer,
+                         ProcessingMode mode);
   void ResetLoopRestorationParams();
   void ReadLoopRestorationCoefficients(int row4x4, int column4x4,
                                        BlockSize block_size);  // 5.11.57.
@@ -251,7 +234,6 @@
   int GetPaletteCache(const Block& block, PlaneType plane_type,
                       uint16_t* cache);
   void ReadPaletteColors(const Block& block, Plane plane);
-  int GetHasPaletteYContext(const Block& block) const;
   void ReadPaletteModeInfo(const Block& block);      // 5.11.46.
   void ReadFilterIntraModeInfo(const Block& block);  // 5.11.24.
   int ReadMotionVectorComponent(const Block& block,
@@ -305,7 +287,7 @@
   void ReadVariableTransformTree(const Block& block, int row4x4, int column4x4,
                                  TransformSize tx_size);
   void DecodeTransformSize(const Block& block);  // 5.11.16.
-  bool ComputePrediction(const Block& block);    // 5.11.33.
+  void ComputePrediction(const Block& block);    // 5.11.33.
   // |x4| and |y4| are the column and row positions of the 4x4 block. |w4| and
   // |h4| are the width and height in 4x4 units of |tx_size|.
   int GetTransformAllZeroContext(const Block& block, Plane plane,
@@ -319,33 +301,32 @@
   void ReadTransformType(const Block& block, int x4, int y4,
                          TransformSize tx_size);  // 5.11.47.
   int GetCoeffBaseContextEob(TransformSize tx_size, int index);
-  int GetCoeffBaseContext2D(TransformSize tx_size, int adjusted_tx_width_log2,
+  int GetCoeffBaseContext2D(const int32_t* quantized_buffer,
+                            TransformSize tx_size, int adjusted_tx_width_log2,
                             uint16_t pos);
-  int GetCoeffBaseContextHorizontal(TransformSize tx_size,
+  int GetCoeffBaseContextHorizontal(const int32_t* quantized_buffer,
+                                    TransformSize tx_size,
                                     int adjusted_tx_width_log2, uint16_t pos);
-  int GetCoeffBaseContextVertical(TransformSize tx_size,
+  int GetCoeffBaseContextVertical(const int32_t* quantized_buffer,
+                                  TransformSize tx_size,
                                   int adjusted_tx_width_log2, uint16_t pos);
-  int GetCoeffBaseRangeContext2D(int adjusted_tx_width_log2, int pos);
-  int GetCoeffBaseRangeContextHorizontal(int adjusted_tx_width_log2, int pos);
-  int GetCoeffBaseRangeContextVertical(int adjusted_tx_width_log2, int pos);
+  int GetCoeffBaseRangeContext2D(const int32_t* quantized_buffer,
+                                 int adjusted_tx_width_log2, int pos);
+  int GetCoeffBaseRangeContextHorizontal(const int32_t* quantized_buffer,
+                                         int adjusted_tx_width_log2, int pos);
+  int GetCoeffBaseRangeContextVertical(const int32_t* quantized_buffer,
+                                       int adjusted_tx_width_log2, int pos);
   int GetDcSignContext(int x4, int y4, int w4, int h4, Plane plane);
   void SetEntropyContexts(int x4, int y4, int w4, int h4, Plane plane,
                           uint8_t coefficient_level, int8_t dc_category);
-  bool InterIntraPrediction(
+  void InterIntraPrediction(
       uint16_t* prediction[2], ptrdiff_t prediction_stride,
       const uint8_t* prediction_mask, ptrdiff_t prediction_mask_stride,
       const PredictionParameters& prediction_parameters, int prediction_width,
       int prediction_height, int subsampling_x, int subsampling_y,
       uint8_t* dest,
       ptrdiff_t dest_stride);  // Part of section 7.11.3.1 in the spec.
-  // Several prediction modes need a prediction mask:
-  // kCompoundPredictionTypeDiffWeighted, kCompoundPredictionTypeWedge,
-  // kCompoundPredictionTypeIntra. They are mutually exclusive. So the mask is
-  // allocated in each case. The mask only needs to be allocated for kPlaneY
-  // and then used for other planes.
-  LIBGAV1_MUST_USE_RESULT static bool AllocatePredictionMask(
-      SuperBlockBuffer* sb_buffer);
-  bool CompoundInterPrediction(
+  void CompoundInterPrediction(
       const Block& block, ptrdiff_t prediction_stride,
       ptrdiff_t prediction_mask_stride, int prediction_width,
       int prediction_height, Plane plane, int subsampling_x, int subsampling_y,
@@ -359,7 +340,7 @@
                               GlobalMotion* global_motion_params,
                               GlobalMotion* local_warp_params)
       const;  // Part of section 7.11.3.1 in the spec.
-  bool InterPrediction(const Block& block, Plane plane, int x, int y,
+  void InterPrediction(const Block& block, Plane plane, int x, int y,
                        int prediction_width, int prediction_height,
                        int candidate_row, int candidate_column,
                        bool* is_local_valid,
@@ -382,27 +363,27 @@
                           int step_y, int ref_block_start_x,
                           int ref_block_end_x, int ref_block_start_y,
                           uint8_t* block_buffer, ptrdiff_t block_stride);
-  void BlockInterPrediction(Plane plane, int reference_frame_index,
-                            const MotionVector& mv, int x, int y, int width,
-                            int height, int candidate_row, int candidate_column,
+  void BlockInterPrediction(const Block& block, Plane plane,
+                            int reference_frame_index, const MotionVector& mv,
+                            int x, int y, int width, int height,
+                            int candidate_row, int candidate_column,
                             uint16_t* prediction, ptrdiff_t prediction_stride,
-                            const uint8_t* round_bits, bool is_compound,
+                            int round_bits, bool is_compound,
                             bool is_inter_intra, uint8_t* dest,
                             ptrdiff_t dest_stride);  // 7.11.3.4.
-  bool BlockWarpProcess(const Block& block, Plane plane, int index,
+  void BlockWarpProcess(const Block& block, Plane plane, int index,
                         int block_start_x, int block_start_y, int width,
                         int height, ptrdiff_t prediction_stride,
-                        GlobalMotion* warp_params, const uint8_t* round_bits,
+                        GlobalMotion* warp_params, int round_bits,
                         bool is_compound, bool is_inter_intra, uint8_t* dest,
                         ptrdiff_t dest_stride);  // 7.11.3.5.
-  void ObmcBlockPrediction(const MotionVector& mv, Plane plane,
-                           int reference_frame_index, int width, int height,
-                           int x, int y, int candidate_row,
+  void ObmcBlockPrediction(const Block& block, const MotionVector& mv,
+                           Plane plane, int reference_frame_index, int width,
+                           int height, int x, int y, int candidate_row,
                            int candidate_column,
-                           ObmcDirection blending_direction,
-                           const uint8_t* round_bits);
+                           ObmcDirection blending_direction, int round_bits);
   void ObmcPrediction(const Block& block, Plane plane, int width, int height,
-                      const uint8_t* round_bits);  // 7.11.3.9.
+                      int round_bits);  // 7.11.3.9.
   void DistanceWeightedPrediction(uint16_t* prediction_0,
                                   ptrdiff_t prediction_stride_0,
                                   uint16_t* prediction_1,
@@ -418,8 +399,8 @@
   // coefficient level.
   template <bool is_dc_coefficient>
   bool ReadSignAndApplyDequantization(
-      const Block& block, const uint16_t* scan, int i,
-      int adjusted_tx_width_log2, int tx_width, int q_value,
+      const Block& block, int32_t* quantized_buffer, const uint16_t* scan,
+      int i, int adjusted_tx_width_log2, int tx_width, int q_value,
       const uint8_t* quantizer_matrix, int shift, int min_value, int max_value,
       uint16_t* dc_sign_cdf, int8_t* dc_category,
       int* coefficient_level);  // Part of 5.11.39.
@@ -450,8 +431,8 @@
   void PopulatePaletteColorContexts(
       const Block& block, PlaneType plane_type, int i, int start, int end,
       uint8_t color_order[kMaxPaletteSquare][kMaxPaletteSize],
-      uint8_t color_context[kMaxPaletteSquare]);                 // 5.11.50.
-  bool ReadPaletteTokens(const Block& block);                    // 5.11.49.
+      uint8_t color_context[kMaxPaletteSquare]);  // 5.11.50.
+  bool ReadPaletteTokens(const Block& block);     // 5.11.49.
   template <typename Pixel>
   void IntraPrediction(const Block& block, Plane plane, int x, int y,
                        bool has_left, bool has_top, bool has_top_right,
@@ -478,6 +459,20 @@
   // for the given |block| and stores them into |current_frame_|.
   void StoreMotionFieldMvsIntoCurrentFrame(const Block& block);
 
+  // Returns the zero-based index of the super block that contains |row4x4|
+  // relative to the start of this tile.
+  int SuperBlockRowIndex(int row4x4) const {
+    return (row4x4 - row4x4_start_) >>
+           (sequence_header_.use_128x128_superblock ? 5 : 4);
+  }
+
+  // Returns the zero-based index of the super block that contains |column4x4|
+  // relative to the start of this tile.
+  int SuperBlockColumnIndex(int column4x4) const {
+    return (column4x4 - column4x4_start_) >>
+           (sequence_header_.use_128x128_superblock ? 5 : 4);
+  }
+
   BlockSize SuperBlockSize() const {
     return sequence_header_.use_128x128_superblock ? kBlock128x128
                                                    : kBlock64x64;
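
Between hunks: the shift in SuperBlockRowIndex()/SuperBlockColumnIndex() is a division by the superblock span in 4x4 units: 128/4 = 32 = 2^5 and 64/4 = 16 = 2^4. A minimal standalone sketch of the same arithmetic (the sample offsets are made up):

#include <cassert>

// Mirrors SuperBlockRowIndex() above: a 128x128 superblock spans 32 rows of
// 4x4 units (shift 5), a 64x64 superblock spans 16 (shift 4).
int SuperBlockIndex(int pos4x4, int start4x4, bool use_128x128_superblock) {
  return (pos4x4 - start4x4) >> (use_128x128_superblock ? 5 : 4);
}

int main() {
  assert(SuperBlockIndex(48, 16, /*use_128x128_superblock=*/false) == 2);
  assert(SuperBlockIndex(48, 16, /*use_128x128_superblock=*/true) == 1);
}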
@@ -556,24 +551,6 @@
   PostFilter& post_filter_;
   BlockParametersHolder& block_parameters_holder_;
   Quantizer quantizer_;
-  // The |quantized_| array is used by ReadTransformCoefficients() to store the
-  // quantized coefficients until the dequantization process is performed. This
-  // is declared as a class variable because only the first few values of this
-  // array will be used by each call to ReadTransformCoefficients() depending on
-  // the transform size.
-  int32_t quantized_[kQuantizedCoefficientBufferSize];
-  // Stores the "color order" for a block for each iteration in
-  // ReadPaletteTokens(). The "color order" is used to populate the
-  // |color_index_map| used for palette prediction. This is declared as a class
-  // variable because only the first few values in each dimension are used by
-  // each call depending on the block size and the palette size.
-  uint8_t color_order_[kMaxPaletteSquare][kMaxPaletteSize];
-  // Stores the "color context" for a block for each iteration in
-  // ReadPaletteTokens(). The "color context" is the cdf context index used to
-  // read the |palette_color_idx_y| variable in the spec. This is declared as a
-  // class variable because only the first few values in each dimension are used
-  // by each call depending on the block size and the palette size.
-  uint8_t color_context_[kMaxPaletteSquare];
   // When there is no multi-threading within the Tile, |residual_buffer_| is
   // used. When there is multi-threading within the Tile,
   // |residual_buffer_threaded_| is used. In the following comment,
@@ -622,6 +599,7 @@
   ThreadPool* const thread_pool_;
   ThreadingParameters threading_;
   ResidualBufferPool* const residual_buffer_pool_;
+  DecoderScratchBufferPool* const decoder_scratch_buffer_pool_;
   BlockingCounterWithStatus* const pending_tiles_;
   bool split_parse_and_decode_;
   // This is used only when |split_parse_and_decode_| is false.
@@ -639,13 +617,17 @@
 
 struct Tile::Block {
   Block(const Tile& tile, int row4x4, int column4x4, BlockSize size,
-        SuperBlockBuffer* const sb_buffer, BlockParameters* const parameters)
+        DecoderScratchBuffer* const scratch_buffer, ResidualPtr* residual,
+        BlockParameters* const parameters)
       : tile(tile),
         row4x4(row4x4),
         column4x4(column4x4),
         size(size),
         left_available(tile.IsInside(row4x4, column4x4 - 1)),
         top_available(tile.IsInside(row4x4 - 1, column4x4)),
+        residual_size{kPlaneResidualSize[size][0][0],
+                      kPlaneResidualSize[size][tile.subsampling_x_[kPlaneU]]
+                                        [tile.subsampling_y_[kPlaneU]]},
         bp_top(top_available
                    ? tile.block_parameters_holder_.Find(row4x4 - 1, column4x4)
                    : nullptr),
@@ -653,8 +635,13 @@
                     ? tile.block_parameters_holder_.Find(row4x4, column4x4 - 1)
                     : nullptr),
         bp(parameters),
-        sb_buffer(sb_buffer) {
+        scratch_buffer(scratch_buffer),
+        residual(residual) {
     assert(size != kBlockInvalid);
+    assert(residual_size[kPlaneTypeY] != kBlockInvalid);
+    if (tile.PlaneCount() > 1) {
+      assert(residual_size[kPlaneTypeUV] != kBlockInvalid);
+    }
   }
 
   bool HasChroma() const {
@@ -756,10 +743,12 @@
   const BlockSize size;
   const bool left_available;
   const bool top_available;
+  const BlockSize residual_size[kNumPlaneTypes];
   BlockParameters* const bp_top;
   BlockParameters* const bp_left;
   BlockParameters* const bp;
-  SuperBlockBuffer* const sb_buffer;
+  DecoderScratchBuffer* const scratch_buffer;
+  ResidualPtr* const residual;
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/tile/bitstream/mode_info.cc b/libgav1/src/tile/bitstream/mode_info.cc
index b2e6039..bdf92bf 100644
--- a/libgav1/src/tile/bitstream/mode_info.cc
+++ b/libgav1/src/tile/bitstream/mode_info.cc
@@ -48,6 +48,19 @@
   kCflSignPositive = 2
 };
 
+// For each possible value of the combined signs (which is read from the
+// bitstream), this array stores the following: sign_u, sign_v, alpha_u_context,
+// alpha_v_context. Only positive entries are used. Entry at index i is computed
+// as follows:
+// sign_u = (i + 1) / 3
+// sign_v = (i + 1) % 3
+// alpha_u_context = i - 2
+// alpha_v_context = (sign_v - 1) * 3 + sign_u
+constexpr int8_t kCflAlphaLookup[kCflAlphaSignsSymbolCount][4] = {
+    {0, 1, -2, 0}, {0, 2, -1, 3}, {1, 0, 0, -2}, {1, 1, 1, 1},
+    {1, 2, 2, 4},  {2, 0, 3, -1}, {2, 1, 4, 2},  {2, 2, 5, 5},
+};
+
 constexpr BitMaskSet kPredictionModeHasNearMvMask(kPredictionModeNearMv,
                                                   kPredictionModeNearNearMv,
                                                   kPredictionModeNearNewMv,
@@ -342,24 +355,26 @@
 void Tile::ReadCflAlpha(const Block& block) {
   const int signs = reader_.ReadSymbol<kCflAlphaSignsSymbolCount>(
       symbol_decoder_context_.cfl_alpha_signs_cdf);
-  const auto sign_u = static_cast<CflSign>((signs + 1) / 3);
-  const auto sign_v = static_cast<CflSign>((signs + 1) % 3);
+  const int8_t* const cfl_lookup = kCflAlphaLookup[signs];
+  const auto sign_u = static_cast<CflSign>(cfl_lookup[0]);
+  const auto sign_v = static_cast<CflSign>(cfl_lookup[1]);
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   prediction_parameters.cfl_alpha_u = 0;
   if (sign_u != kCflSignZero) {
+    assert(cfl_lookup[2] >= 0);
     prediction_parameters.cfl_alpha_u =
         reader_.ReadSymbol<kCflAlphaSymbolCount>(
-            symbol_decoder_context_.cfl_alpha_cdf[signs - 2]) +
+            symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[2]]) +
         1;
     if (sign_u == kCflSignNegative) prediction_parameters.cfl_alpha_u *= -1;
   }
   prediction_parameters.cfl_alpha_v = 0;
   if (sign_v != kCflSignZero) {
-    const int context = (sign_v - 1) * 3 + sign_u;
+    assert(cfl_lookup[3] >= 0);
     prediction_parameters.cfl_alpha_v =
         reader_.ReadSymbol<kCflAlphaSymbolCount>(
-            symbol_decoder_context_.cfl_alpha_cdf[context]) +
+            symbol_decoder_context_.cfl_alpha_cdf[cfl_lookup[3]]) +
         1;
     if (sign_v == kCflSignNegative) prediction_parameters.cfl_alpha_v *= -1;
   }
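
Between hunks: the lookup table replaces per-symbol division and modulo in ReadCflAlpha(), and every entry can be regenerated from the formulas in the comment above it. A self-contained check (entry 4 is the {1, 2, 2, 4} row):

#include <cassert>
#include <cstdint>

int main() {
  int8_t lookup[8][4];  // kCflAlphaSignsSymbolCount is 8.
  for (int i = 0; i < 8; ++i) {
    const int sign_u = (i + 1) / 3;
    const int sign_v = (i + 1) % 3;
    lookup[i][0] = static_cast<int8_t>(sign_u);
    lookup[i][1] = static_cast<int8_t>(sign_v);
    lookup[i][2] = static_cast<int8_t>(i - 2);                      // alpha_u_context
    lookup[i][3] = static_cast<int8_t>((sign_v - 1) * 3 + sign_u);  // alpha_v_context
  }
  assert(lookup[4][0] == 1 && lookup[4][1] == 2 && lookup[4][2] == 2 &&
         lookup[4][3] == 4);
}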
@@ -369,9 +384,7 @@
   BlockParameters& bp = *block.bp;
   bool chroma_from_luma_allowed;
   if (frame_header_.segmentation.lossless[bp.segment_id]) {
-    chroma_from_luma_allowed =
-        kPlaneResidualSize[block.size][subsampling_x_[kPlaneU]]
-                          [subsampling_y_[kPlaneU]] == kBlock4x4;
+    chroma_from_luma_allowed = block.residual_size[kPlaneTypeUV] == kBlock4x4;
   } else {
     chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
   }
@@ -1148,12 +1161,11 @@
       prediction_parameters.wedge_index =
           reader_.ReadSymbol<kWedgeIndexSymbolCount>(
               symbol_decoder_context_.wedge_index_cdf[block.size]);
-      prediction_parameters.wedge_sign =
-          static_cast<int>(reader_.ReadLiteral(1));
+      prediction_parameters.wedge_sign = static_cast<int>(reader_.ReadBit());
     } else if (prediction_parameters.compound_prediction_type ==
                kCompoundPredictionTypeDiffWeighted) {
       prediction_parameters.mask_is_inverse =
-          static_cast<bool>(reader_.ReadLiteral(1));
+          static_cast<bool>(reader_.ReadBit());
     }
     return;
   }
diff --git a/libgav1/src/tile/bitstream/palette.cc b/libgav1/src/tile/bitstream/palette.cc
index 133b016..2438ca9 100644
--- a/libgav1/src/tile/bitstream/palette.cc
+++ b/libgav1/src/tile/bitstream/palette.cc
@@ -12,88 +12,34 @@
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
 #include "src/utils/entropy_decoder.h"
+#include "src/utils/memory.h"
 #include "src/utils/types.h"
 
 namespace libgav1 {
-namespace {
-
-// Add |value| to the |cache| if it doesn't already exist.
-void MaybeAddToPaletteCache(uint16_t value, uint16_t* const cache,
-                            int* const n) {
-  assert(cache != nullptr);
-  assert(n != nullptr);
-  if (*n > 0 && value == cache[*n - 1]) return;
-  cache[(*n)++] = value;
-}
-
-// Palette colors are generated using two ascending arrays. So sorting them is
-// simply a matter of finding the pivot point and merging two sorted arrays.
-void SortPaletteColors(uint16_t* const color, const int size, const int pivot) {
-  if (pivot == 0 || pivot == size || color[pivot - 1] < color[pivot]) {
-    // The array is already sorted.
-    return;
-  }
-  uint16_t temp_color[kMaxPaletteSize];
-  memcpy(temp_color, color, size * sizeof(color[0]));
-  int i = 0;
-  int j = pivot;
-  int k = 0;
-  while (i < pivot && j < size) {
-    if (temp_color[i] < temp_color[j]) {
-      color[k++] = temp_color[i++];
-    } else {
-      color[k++] = temp_color[j++];
-    }
-  }
-  while (i < pivot) {
-    color[k++] = temp_color[i++];
-  }
-  while (j < size) {
-    color[k++] = temp_color[j++];
-  }
-}
-
-}  // namespace
 
 int Tile::GetPaletteCache(const Block& block, PlaneType plane_type,
                           uint16_t* const cache) {
-  const int top_n =
+  const int top_size =
       (block.top_available && Mod64(MultiplyBy4(block.row4x4)) != 0)
           ? block.bp_top->palette_mode_info.size[plane_type]
           : 0;
-  const int left_n = block.left_available
-                         ? block.bp_left->palette_mode_info.size[plane_type]
-                         : 0;
-  int top_index = 0;
-  int left_index = 0;
-  int n = 0;
-  while (top_index < top_n && left_index < left_n) {
-    const int top_color =
-        block.bp_top->palette_mode_info.color[plane_type][top_index];
-    const int left_color =
-        block.bp_left->palette_mode_info.color[plane_type][left_index];
-    if (left_color < top_color) {
-      MaybeAddToPaletteCache(left_color, cache, &n);
-      ++left_index;
-    } else {
-      MaybeAddToPaletteCache(top_color, cache, &n);
-      ++top_index;
-      if (top_color == left_color) ++left_index;
-    }
-  }
-  while (top_index < top_n) {
-    MaybeAddToPaletteCache(
-        block.bp_top->palette_mode_info.color[plane_type][top_index], cache,
-        &n);
-    ++top_index;
-  }
-  while (left_index < left_n) {
-    MaybeAddToPaletteCache(
-        block.bp_left->palette_mode_info.color[plane_type][left_index], cache,
-        &n);
-    ++left_index;
-  }
-  return n;
+  const int left_size = block.left_available
+                            ? block.bp_left->palette_mode_info.size[plane_type]
+                            : 0;
+  if (left_size == 0 && top_size == 0) return 0;
+  // Merge the left and top colors in sorted order and store them in |cache|.
+  uint16_t dummy[1];
+  uint16_t* top = (top_size > 0)
+                      ? block.bp_top->palette_mode_info.color[plane_type]
+                      : dummy;
+  uint16_t* left = (left_size > 0)
+                       ? block.bp_left->palette_mode_info.color[plane_type]
+                       : dummy;
+  std::merge(top, top + top_size, left, left + left_size, cache);
+  // Deduplicate the entries in |cache| and return the number of unique
+  // entries.
+  return static_cast<int>(
+      std::distance(cache, std::unique(cache, cache + left_size + top_size)));
 }
 
 void Tile::ReadPaletteColors(const Block& block, Plane plane) {
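
Between hunks: GetPaletteCache() now relies on each neighbor's palette already being sorted ascending, so merging and deduplicating with the standard library replaces the hand-written merge loop. A self-contained sketch with made-up palette values:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>

int main() {
  const uint16_t top[] = {10, 20, 30};       // as bp_top's palette would be
  const uint16_t left[] = {20, 25, 30, 40};  // as bp_left's palette would be
  uint16_t cache[7];
  std::merge(std::begin(top), std::end(top), std::begin(left), std::end(left),
             cache);
  // cache == {10, 20, 20, 25, 30, 30, 40}; std::unique drops adjacent
  // duplicates, which suffices because the merged output is sorted.
  const int n = static_cast<int>(
      std::distance(cache, std::unique(cache, cache + 7)));
  assert(n == 5);  // {10, 20, 25, 30, 40}
}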
@@ -115,32 +61,49 @@
     palette_color[index++] =
         static_cast<uint16_t>(reader_.ReadLiteral(bitdepth));
   }
+  const int max_value = (1 << bitdepth) - 1;
   if (index < palette_size) {
     int bits = bitdepth - 3 + static_cast<int>(reader_.ReadLiteral(2));
-    for (; index < palette_size; ++index) {
+    do {
       const int delta = static_cast<int>(reader_.ReadLiteral(bits)) +
                         (plane_type == kPlaneTypeY ? 1 : 0);
       palette_color[index] =
-          Clip3(palette_color[index - 1] + delta, 0, (1 << bitdepth) - 1);
+          std::min(palette_color[index - 1] + delta, max_value);
+      if (palette_color[index] + (plane_type == kPlaneTypeY ? 1 : 0) >=
+          max_value) {
+        // Once the color exceeds max_value, all others can be set to max_value
+        // (since they are computed as a delta on top of the current color and
+        // then clipped).
+        Memset(&palette_color[index + 1], max_value, palette_size - index - 1);
+        break;
+      }
       const int range = (1 << bitdepth) - palette_color[index] -
                         (plane_type == kPlaneTypeY ? 1 : 0);
       bits = std::min(bits, CeilLog2(range));
-    }
+    } while (++index < palette_size);
   }
-  SortPaletteColors(palette_color, palette_size, merge_pivot);
+  // Palette colors are generated using two ascending arrays. So sorting them is
+  // simply a matter of merging the two sorted portions of the array.
+  std::inplace_merge(palette_color, palette_color + merge_pivot,
+                     palette_color + palette_size);
   if (plane_type == kPlaneTypeUV) {
     uint16_t* const palette_color_v = bp.palette_mode_info.color[kPlaneV];
     if (reader_.ReadBit() != 0) {  // delta_encode_palette_colors_v.
-      const int max_value = 1 << bitdepth;
       const int bits = bitdepth - 4 + static_cast<int>(reader_.ReadLiteral(2));
       palette_color_v[0] = reader_.ReadLiteral(bitdepth);
       for (int i = 1; i < palette_size; ++i) {
         int delta = static_cast<int>(reader_.ReadLiteral(bits));
         if (delta != 0 && reader_.ReadBit() != 0) delta = -delta;
-        int value = palette_color_v[i - 1] + delta;
-        if (value < 0) value += max_value;
-        if (value >= max_value) value -= max_value;
-        palette_color_v[i] = Clip3(value, 0, (1 << bitdepth) - 1);
+        // This line is equivalent to the following lines in the spec:
+        // val = palette_colors_v[ idx - 1 ] + palette_delta_v
+        // if ( val < 0 ) val += maxVal
+        // if ( val >= maxVal ) val -= maxVal
+        // palette_colors_v[ idx ] = Clip1( val )
+        //
+        // The difference is that in the code, max_value is (1 << bitdepth) - 1.
+        // So "& max_value" has the desired effect of computing both the "if"
+        // conditions and the Clip.
+        palette_color_v[i] = (palette_color_v[i - 1] + delta) & max_value;
       }
     } else {
       for (int i = 0; i < palette_size; ++i) {
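
Between hunks: the "& max_value" equivalence claimed in the comment above holds because max_value is 2^bitdepth - 1, so the AND computes the value modulo 2^bitdepth even for negative intermediates (two's complement). A quick check with arbitrary 8-bit values:

#include <cassert>

int main() {
  const int bitdepth = 8;
  const int max_value = (1 << bitdepth) - 1;  // 255
  // Wrap upward: 250 + 10 = 260; the spec subtracts maxVal (256), giving 4.
  assert(((250 + 10) & max_value) == 4);
  // Wrap downward: 3 - 5 = -2; the spec adds maxVal (256), giving 254. The
  // bitwise AND of the two's-complement representation agrees.
  assert(((3 - 5) & max_value) == 254);
}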
@@ -151,19 +114,6 @@
   }
 }
 
-int Tile::GetHasPaletteYContext(const Block& block) const {
-  int context = 0;
-  if (block.top_available &&
-      block.bp_top->palette_mode_info.size[kPlaneTypeY] > 0) {
-    ++context;
-  }
-  if (block.left_available &&
-      block.bp_left->palette_mode_info.size[kPlaneTypeY] > 0) {
-    ++context;
-  }
-  return context;
-}
-
 void Tile::ReadPaletteModeInfo(const Block& block) {
   BlockParameters& bp = *block.bp;
   if (IsBlockSmallerThan8x8(block.size) || block.size > kBlock64x64 ||
@@ -175,7 +125,13 @@
   const int block_size_context =
       k4x4WidthLog2[block.size] + k4x4HeightLog2[block.size] - 2;
   if (bp.y_mode == kPredictionModeDc) {
-    const int context = GetHasPaletteYContext(block);
+    const int context =
+        static_cast<int>(block.top_available &&
+                         block.bp_top->palette_mode_info.size[kPlaneTypeY] >
+                             0) +
+        static_cast<int>(block.left_available &&
+                         block.bp_left->palette_mode_info.size[kPlaneTypeY] >
+                             0);
     const bool has_palette_y = reader_.ReadSymbol(
         symbol_decoder_context_.has_palette_y_cdf[block_size_context][context]);
     if (has_palette_y) {
@@ -186,8 +142,7 @@
       ReadPaletteColors(block, kPlaneY);
     }
   }
-  if (PlaneCount() > 1 && bp.uv_mode == kPredictionModeDc &&
-      block.HasChroma()) {
+  if (bp.uv_mode == kPredictionModeDc && block.HasChroma()) {
     const int context =
         static_cast<int>(bp.palette_mode_info.size[kPlaneTypeY] > 0);
     const bool has_palette_uv =
diff --git a/libgav1/src/tile/prediction.cc b/libgav1/src/tile/prediction.cc
index 6759838..dbf314d 100644
--- a/libgav1/src/tile/prediction.cc
+++ b/libgav1/src/tile/prediction.cc
@@ -63,30 +63,20 @@
   return kDirectionalIntraPredictorDerivative[DivideBy2(angle) - 1];
 }
 
+// Maps the block_size to an index as follows:
+//  kBlock8x8 => 0.
+//  kBlock8x16 => 1.
+//  kBlock8x32 => 2.
+//  kBlock16x8 => 3.
+//  kBlock16x16 => 4.
+//  kBlock16x32 => 5.
+//  kBlock32x8 => 6.
+//  kBlock32x16 => 7.
+//  kBlock32x32 => 8.
 int GetWedgeBlockSizeIndex(BlockSize block_size) {
   assert(block_size >= kBlock8x8);
-  switch (block_size) {
-    case kBlock8x8:
-      return 0;
-    case kBlock8x16:
-      return 1;
-    case kBlock8x32:
-      return 2;
-    case kBlock16x8:
-      return 3;
-    case kBlock16x16:
-      return 4;
-    case kBlock16x32:
-      return 5;
-    case kBlock32x8:
-      return 6;
-    case kBlock32x16:
-      return 7;
-    case kBlock32x32:
-      return 8;
-    default:
-      return -1;
-  }
+  return block_size - kBlock8x8 - static_cast<int>(block_size >= kBlock16x8) -
+         static_cast<int>(block_size >= kBlock32x8);
 }
 
 // 7.11.2.9.
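
Between hunks: the closed form in the hunk above works because the wedge-eligible sizes form a consecutive run in the BlockSize enum interrupted by exactly two non-wedge entries, kBlock16x4 (before kBlock16x8) and kBlock16x64 (before kBlock32x8); each comparison subtracts one once that gap is crossed. A standalone check against the old switch's mapping (the local enum mirrors the relevant slice of libgav1's ordering, an assumption for this sketch):

#include <cassert>

enum BlockSize {
  kBlock8x8, kBlock8x16, kBlock8x32, kBlock16x4, kBlock16x8, kBlock16x16,
  kBlock16x32, kBlock16x64, kBlock32x8, kBlock32x16, kBlock32x32
};

int GetWedgeBlockSizeIndex(BlockSize block_size) {
  return block_size - kBlock8x8 - static_cast<int>(block_size >= kBlock16x8) -
         static_cast<int>(block_size >= kBlock32x8);
}

int main() {
  assert(GetWedgeBlockSizeIndex(kBlock8x8) == 0);
  assert(GetWedgeBlockSizeIndex(kBlock8x32) == 2);
  assert(GetWedgeBlockSizeIndex(kBlock16x8) == 3);
  assert(GetWedgeBlockSizeIndex(kBlock16x32) == 5);
  assert(GetWedgeBlockSizeIndex(kBlock32x8) == 6);
  assert(GetWedgeBlockSizeIndex(kBlock32x32) == 8);
}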
@@ -211,19 +201,15 @@
   }
 }
 
-// 7.11.3.2.
-void GetInterRoundingBits(const bool is_compound, const int bitdepth,
-                          uint8_t round_bits[2]) {
-  round_bits[0] = 3;
-  round_bits[1] = is_compound ? 7 : 11;
+// 7.11.3.2. Note InterRoundBits0 is derived in the dsp layer.
+int GetInterRoundingBits(const bool is_compound, const int bitdepth) {
+  if (is_compound) return 7;
 #if LIBGAV1_MAX_BITDEPTH == 12
-  if (bitdepth == 12) {
-    round_bits[0] += 2;
-    if (!is_compound) round_bits[1] -= 2;
-  }
+  if (bitdepth == 12) return 9;
 #else
   static_cast<void>(bitdepth);
 #endif
+  return 11;
 }
 
 uint8_t* GetStartPoint(Array2DView<uint8_t>* const buffer, const int plane,
@@ -557,23 +543,23 @@
   Array2DView<Pixel> y_buffer(
       buffer_[kPlaneY].rows(), buffer_[kPlaneY].columns() / sizeof(Pixel),
       reinterpret_cast<Pixel*>(&buffer_[kPlaneY][0][0]));
-  if (!block.sb_buffer->cfl_luma_buffer_valid) {
+  if (!block.scratch_buffer->cfl_luma_buffer_valid) {
     const int luma_x = start_x << subsampling_x;
     const int luma_y = start_y << subsampling_y;
     dsp_.cfl_subsamplers[tx_size][subsampling_x + subsampling_y](
-        block.sb_buffer->cfl_luma_buffer,
+        block.scratch_buffer->cfl_luma_buffer,
         prediction_parameters.max_luma_width - luma_x,
         prediction_parameters.max_luma_height - luma_y,
         reinterpret_cast<uint8_t*>(&y_buffer[luma_y][luma_x]),
         buffer_[kPlaneY].columns());
-    block.sb_buffer->cfl_luma_buffer_valid = true;
+    block.scratch_buffer->cfl_luma_buffer_valid = true;
   }
   Array2DView<Pixel> buffer(buffer_[plane].rows(),
                             buffer_[plane].columns() / sizeof(Pixel),
                             reinterpret_cast<Pixel*>(&buffer_[plane][0][0]));
   dsp_.cfl_intra_predictors[tx_size](
       reinterpret_cast<uint8_t*>(&buffer[start_y][start_x]),
-      buffer_[plane].columns(), block.sb_buffer->cfl_luma_buffer,
+      buffer_[plane].columns(), block.scratch_buffer->cfl_luma_buffer,
       (plane == kPlaneU) ? prediction_parameters.cfl_alpha_u
                          : prediction_parameters.cfl_alpha_v);
 }
@@ -587,7 +573,7 @@
     const TransformSize tx_size);
 #endif
 
-bool Tile::InterIntraPrediction(
+void Tile::InterIntraPrediction(
     uint16_t* prediction[2], const ptrdiff_t prediction_stride,
     const uint8_t* const prediction_mask,
     const ptrdiff_t prediction_mask_stride,
@@ -602,7 +588,6 @@
              kCompoundPredictionTypeWedge);
   // The first buffer of InterIntra is from inter prediction.
   // The second buffer is from intra prediction.
-  Array2D<uint16_t> intra_prediction;
   ptrdiff_t intra_stride;
   const int bitdepth = sequence_header_.color_config.bitdepth;
   if (bitdepth == 8) {
@@ -610,20 +595,16 @@
     // 8, |buffer_| is uint8_t and hence a copy has to be made. For higher
     // bitdepths, the |buffer_| itself can act as an uint16_t buffer so no
     // copy is necessary.
-    if (!intra_prediction.Reset(prediction_height, prediction_width)) {
-      LIBGAV1_DLOG(ERROR,
-                   "Can't allocate memory for the intra prediction block.");
-      return false;
-    }
     uint8_t* dest_ptr = dest;
+    Array2DView<uint16_t> intra_prediction(
+        kMaxSuperBlockSizeInPixels, kMaxSuperBlockSizeInPixels, prediction[1]);
     for (int r = 0; r < prediction_height; ++r) {
       for (int c = 0; c < prediction_width; ++c) {
         intra_prediction[r][c] = dest_ptr[c];
       }
       dest_ptr += dest_stride;
     }
-    prediction[1] = intra_prediction.data();
-    intra_stride = prediction_width;
+    intra_stride = kMaxSuperBlockSizeInPixels;
   } else {
     prediction[1] = reinterpret_cast<uint16_t*>(dest);
     intra_stride = dest_stride / sizeof(uint16_t);
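
Between hunks: in the 8-bit branch above, the intra prediction already written to |dest| must be widened into the fixed-stride uint16_t scratch area before mask blending. A stripped-down version of that copy (kStride and WidenTo16Bit are illustrative names, not part of the patch):

#include <cstddef>
#include <cstdint>

constexpr int kStride = 128;  // stand-in for kMaxSuperBlockSizeInPixels

// Widens 8-bit pixels from |src| into a uint16_t buffer with a fixed stride,
// mirroring the copy into prediction[1] above.
void WidenTo16Bit(const uint8_t* src, ptrdiff_t src_stride, int width,
                  int height, uint16_t* dst) {
  for (int r = 0; r < height; ++r) {
    for (int c = 0; c < width; ++c) dst[c] = src[c];
    src += src_stride;
    dst += kStride;
  }
}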
@@ -634,23 +615,9 @@
                                   prediction[1], intra_stride, prediction_mask,
                                   prediction_mask_stride, prediction_width,
                                   prediction_height, dest, dest_stride);
-  return true;
 }
 
-bool Tile::AllocatePredictionMask(SuperBlockBuffer* sb_buffer) {
-  if (sb_buffer->prediction_mask != nullptr) {
-    return true;
-  }
-  sb_buffer->prediction_mask = MakeAlignedUniquePtr<uint8_t>(
-      16, kMaxSuperBlockSizeInPixels * kMaxSuperBlockSizeInPixels);
-  if (sb_buffer->prediction_mask == nullptr) {
-    LIBGAV1_DLOG(ERROR, "Allocation of prediction_mask failed.");
-    return false;
-  }
-  return true;
-}
-
-bool Tile::CompoundInterPrediction(
+void Tile::CompoundInterPrediction(
     const Block& block, const ptrdiff_t prediction_stride,
     const ptrdiff_t prediction_mask_stride, const int prediction_width,
     const int prediction_height, const Plane plane, const int subsampling_x,
@@ -658,33 +625,30 @@
     const int candidate_column, uint8_t* dest, const ptrdiff_t dest_stride) {
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
-  uint16_t* prediction[2] = {block.sb_buffer->prediction_buffer[0].get(),
-                             block.sb_buffer->prediction_buffer[1].get()};
+  uint16_t* prediction[2] = {block.scratch_buffer->prediction_buffer[0],
+                             block.scratch_buffer->prediction_buffer[1]};
   switch (prediction_parameters.compound_prediction_type) {
     case kCompoundPredictionTypeWedge:
       GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
                        prediction_parameters.is_wedge_inter_intra,
                        subsampling_x, subsampling_y)(
           prediction[0], prediction_stride, prediction[1], prediction_stride,
-          block.sb_buffer->prediction_mask.get(), prediction_mask_stride,
+          block.scratch_buffer->prediction_mask, prediction_mask_stride,
           prediction_width, prediction_height, dest, dest_stride);
       break;
     case kCompoundPredictionTypeDiffWeighted:
       if (plane == kPlaneY) {
-        if (!AllocatePredictionMask(block.sb_buffer)) {
-          return false;
-        }
         GenerateWeightMask(
             prediction[0], prediction_stride, prediction[1], prediction_stride,
             prediction_parameters.mask_is_inverse, prediction_width,
-            prediction_height, bitdepth, block.sb_buffer->prediction_mask.get(),
+            prediction_height, bitdepth, block.scratch_buffer->prediction_mask,
             prediction_mask_stride);
       }
       GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
                        prediction_parameters.is_wedge_inter_intra,
                        subsampling_x, subsampling_y)(
           prediction[0], prediction_stride, prediction[1], prediction_stride,
-          block.sb_buffer->prediction_mask.get(), prediction_mask_stride,
+          block.scratch_buffer->prediction_mask, prediction_mask_stride,
           prediction_width, prediction_height, dest, dest_stride);
       break;
     case kCompoundPredictionTypeDistance:
@@ -693,16 +657,14 @@
           prediction_width, prediction_height, candidate_row, candidate_column,
           dest, dest_stride);
       break;
-    case kCompoundPredictionTypeAverage:
+    default:
+      assert(prediction_parameters.compound_prediction_type ==
+             kCompoundPredictionTypeAverage);
       dsp_.average_blend(prediction[0], prediction_stride, prediction[1],
                          prediction_stride, prediction_width, prediction_height,
                          dest, dest_stride);
       break;
-    default:
-      assert(false && "This is not a compound type.\n");
-      return false;
   }
-  return true;
 }
 
 GlobalMotion* Tile::GetWarpParams(
@@ -745,7 +707,7 @@
   return nullptr;
 }
 
-bool Tile::InterPrediction(const Block& block, const Plane plane, const int x,
+void Tile::InterPrediction(const Block& block, const Plane plane, const int x,
                            const int y, const int prediction_width,
                            const int prediction_height, int candidate_row,
                            int candidate_column, bool* const is_local_valid,
@@ -760,38 +722,12 @@
       bp.is_inter && bp.reference_frame[1] == kReferenceFrameIntra;
   const ptrdiff_t prediction_stride = prediction_width;
 
-  const size_t prediction_buffer_size = prediction_height * prediction_stride;
-  if (block.sb_buffer->prediction_buffer_size[0] < prediction_buffer_size) {
-    block.sb_buffer->prediction_buffer[0] =
-        MakeAlignedUniquePtr<uint16_t>(8, prediction_buffer_size);
-    if (block.sb_buffer->prediction_buffer[0] == nullptr) {
-      block.sb_buffer->prediction_buffer_size[0] = 0;
-      LIBGAV1_DLOG(ERROR,
-                   "Can't allocate memory for the first prediction block.");
-      return false;
-    }
-    block.sb_buffer->prediction_buffer_size[0] = prediction_buffer_size;
-  }
-  if (is_compound &&
-      block.sb_buffer->prediction_buffer_size[1] < prediction_buffer_size) {
-    block.sb_buffer->prediction_buffer[1] =
-        MakeAlignedUniquePtr<uint16_t>(8, prediction_buffer_size);
-    if (block.sb_buffer->prediction_buffer[1] == nullptr) {
-      block.sb_buffer->prediction_buffer_size[1] = 0;
-      LIBGAV1_DLOG(ERROR,
-                   "Can't allocate memory for the second prediction block.");
-      return false;
-    }
-    block.sb_buffer->prediction_buffer_size[1] = prediction_buffer_size;
-  }
-
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   uint8_t* const dest = GetStartPoint(buffer_, plane, x, y, bitdepth);
   const ptrdiff_t dest_stride = buffer_[plane].columns();  // In bytes.
-  uint8_t round_bits[2];
-  GetInterRoundingBits(is_compound, sequence_header_.color_config.bitdepth,
-                       round_bits);
+  const int round_bits =
+      GetInterRoundingBits(is_compound, sequence_header_.color_config.bitdepth);
   for (int index = 0; index < 1 + static_cast<int>(is_compound); ++index) {
     const ReferenceFrameType reference_type =
         bp_reference.reference_frame[index];
@@ -802,12 +738,10 @@
                       prediction_parameters, reference_type, is_local_valid,
                       &global_motion_params, local_warp_params);
     if (warp_params != nullptr) {
-      if (!BlockWarpProcess(block, plane, index, x, y, prediction_width,
-                            prediction_height, prediction_stride, warp_params,
-                            round_bits, is_compound, is_inter_intra, dest,
-                            dest_stride)) {
-        return false;
-      }
+      BlockWarpProcess(block, plane, index, x, y, prediction_width,
+                       prediction_height, prediction_stride, warp_params,
+                       round_bits, is_compound, is_inter_intra, dest,
+                       dest_stride);
     } else {
       const int reference_index =
           prediction_parameters.use_intra_block_copy
@@ -815,9 +749,9 @@
               : frame_header_.reference_frame_index[reference_type -
                                                     kReferenceFrameLast];
       BlockInterPrediction(
-          plane, reference_index, bp_reference.mv[index], x, y,
+          block, plane, reference_index, bp_reference.mv[index], x, y,
           prediction_width, prediction_height, candidate_row, candidate_column,
-          block.sb_buffer->prediction_buffer[index].get(), prediction_stride,
+          block.scratch_buffer->prediction_buffer[index], prediction_stride,
           round_bits, is_compound, is_inter_intra, dest, dest_stride);
     }
   }
@@ -838,32 +772,23 @@
     const int offset = block_size_index * wedge_mask_stride_3 +
                        prediction_parameters.wedge_sign * wedge_mask_stride_2 +
                        prediction_parameters.wedge_index * wedge_mask_stride_1;
-    if (!AllocatePredictionMask(block.sb_buffer)) {
-      return false;
-    }
     PopulatePredictionMaskFromWedgeMask(
         &wedge_masks_[offset], kWedgeMaskMasterSize, prediction_width,
-        prediction_height, block.sb_buffer->prediction_mask.get(),
+        prediction_height, block.scratch_buffer->prediction_mask,
         kMaxSuperBlockSizeInPixels);
   } else if (prediction_parameters.compound_prediction_type ==
              kCompoundPredictionTypeIntra) {
-    if (plane == kPlaneY) {
-      if (!AllocatePredictionMask(block.sb_buffer)) {
-        return false;
-      }
-    }
     GenerateInterIntraMask(prediction_parameters.inter_intra_mode,
                            prediction_width, prediction_height,
-                           block.sb_buffer->prediction_mask.get(),
+                           block.scratch_buffer->prediction_mask,
                            prediction_mask_stride);
   }
 
-  bool ok = true;
   if (is_compound) {
-    ok = CompoundInterPrediction(
-        block, prediction_stride, prediction_mask_stride, prediction_width,
-        prediction_height, plane, subsampling_x, subsampling_y, bitdepth,
-        candidate_row, candidate_column, dest, dest_stride);
+    CompoundInterPrediction(block, prediction_stride, prediction_mask_stride,
+                            prediction_width, prediction_height, plane,
+                            subsampling_x, subsampling_y, bitdepth,
+                            candidate_row, candidate_column, dest, dest_stride);
   } else {
     if (prediction_parameters.motion_mode == kMotionModeObmc) {
       // Obmc mode is allowed only for single reference (!is_compound).
@@ -872,25 +797,25 @@
     } else if (is_inter_intra) {
       // InterIntra and obmc must be mutually exclusive.
       uint16_t* prediction_ptr[2] = {
-          block.sb_buffer->prediction_buffer[0].get(),
-          block.sb_buffer->prediction_buffer[1].get()};
-      ok = InterIntraPrediction(
-          prediction_ptr, prediction_stride,
-          block.sb_buffer->prediction_mask.get(), prediction_mask_stride,
-          prediction_parameters, prediction_width, prediction_height,
-          subsampling_x, subsampling_y, dest, dest_stride);
+          block.scratch_buffer->prediction_buffer[0],
+          block.scratch_buffer->prediction_buffer[1]};
+      InterIntraPrediction(prediction_ptr, prediction_stride,
+                           block.scratch_buffer->prediction_mask,
+                           prediction_mask_stride, prediction_parameters,
+                           prediction_width, prediction_height, subsampling_x,
+                           subsampling_y, dest, dest_stride);
     }
   }
-  return ok;
 }
 
-void Tile::ObmcBlockPrediction(const MotionVector& mv, const Plane plane,
+void Tile::ObmcBlockPrediction(const Block& block, const MotionVector& mv,
+                               const Plane plane,
                                const int reference_frame_index, const int width,
                                const int height, const int x, const int y,
                                const int candidate_row,
                                const int candidate_column,
                                const ObmcDirection blending_direction,
-                               const uint8_t* const round_bits) {
+                               const int round_bits) {
   const int bitdepth = sequence_header_.color_config.bitdepth;
   // Obmc's prediction needs to be clipped before blending with above/left
   // prediction blocks.
@@ -901,8 +826,8 @@
   ];
   const ptrdiff_t obmc_clipped_prediction_stride =
       (bitdepth == 8) ? width : width * sizeof(uint16_t);
-  BlockInterPrediction(plane, reference_frame_index, mv, x, y, width, height,
-                       candidate_row, candidate_column, nullptr, width,
+  BlockInterPrediction(block, plane, reference_frame_index, mv, x, y, width,
+                       height, candidate_row, candidate_column, nullptr, width,
                        round_bits, false, false, obmc_clipped_prediction,
                        obmc_clipped_prediction_stride);
 
@@ -915,16 +840,14 @@
 
 void Tile::ObmcPrediction(const Block& block, const Plane plane,
                           const int width, const int height,
-                          const uint8_t* const round_bits) {
+                          const int round_bits) {
   const int subsampling_x = subsampling_x_[plane];
   const int subsampling_y = subsampling_y_[plane];
-  const BlockSize plane_block_size =
-      kPlaneResidualSize[block.size][subsampling_x][subsampling_y];
-  assert(plane_block_size != kBlockInvalid);
   const int num4x4_wide = kNum4x4BlocksWide[block.size];
   const int num4x4_high = kNum4x4BlocksHigh[block.size];
 
-  if (block.top_available && !IsBlockSmallerThan8x8(plane_block_size)) {
+  if (block.top_available &&
+      !IsBlockSmallerThan8x8(block.residual_size[GetPlaneType(plane)])) {
     const int num_limit = std::min(uint8_t{4}, k4x4WidthLog2[block.size]);
     const int column4x4_max =
         std::min(block.column4x4 + num4x4_wide, frame_header_.columns4x4);
@@ -946,7 +869,7 @@
                                                 kReferenceFrameLast];
         const int prediction_width =
             std::min(width, MultiplyBy4(step) >> subsampling_x);
-        ObmcBlockPrediction(bp_top.mv[0], plane,
+        ObmcBlockPrediction(block, bp_top.mv[0], plane,
                             candidate_reference_frame_index, prediction_width,
                             prediction_height,
                             MultiplyBy4(column4x4) >> subsampling_x,
@@ -979,7 +902,7 @@
         const int prediction_height =
             std::min(height, MultiplyBy4(step) >> subsampling_y);
         ObmcBlockPrediction(
-            bp_left.mv[0], plane, candidate_reference_frame_index,
+            block, bp_left.mv[0], plane, candidate_reference_frame_index,
             prediction_width, prediction_height, block_start_x,
             MultiplyBy4(row4x4) >> subsampling_y, candidate_row,
             candidate_column, kObmcDirectionHorizontal, round_bits);
@@ -1123,13 +1046,12 @@
 }
 
 void Tile::BlockInterPrediction(
-    const Plane plane, const int reference_frame_index, const MotionVector& mv,
-    const int x, const int y, const int width, const int height,
-    const int candidate_row, const int candidate_column,
+    const Block& block, const Plane plane, const int reference_frame_index,
+    const MotionVector& mv, const int x, const int y, const int width,
+    const int height, const int candidate_row, const int candidate_column,
     uint16_t* const prediction, const ptrdiff_t prediction_stride,
-    const uint8_t* const round_bits, const bool is_compound,
-    const bool is_inter_intra, uint8_t* const dest,
-    const ptrdiff_t dest_stride) {
+    const int round_bits, const bool is_compound, const bool is_inter_intra,
+    uint8_t* const dest, const ptrdiff_t dest_stride) {
   const BlockParameters& bp =
       *block_parameters_holder_.Find(candidate_row, candidate_column);
   int start_x;
@@ -1179,7 +1101,6 @@
       reference_buffer->right_border(plane),
       reference_buffer->bottom_border(plane), &ref_block_start_x,
       &ref_block_start_y, &ref_block_end_x, &ref_block_end_y);
-  AlignedUniquePtr<uint8_t> block_buffer;
   const uint8_t* block_start = nullptr;
   ptrdiff_t block_stride;
   if (!extend_block) {
@@ -1204,34 +1125,22 @@
     block_stride =
         (2 * width + kConvolveBorderLeftTop + kConvolveBorderRightBottom) *
         pixel_size;
-    const int alignment = 16;
-    int block_height =
-        height + kConvolveBorderLeftTop + kConvolveBorderRightBottom;
-    if (is_scaled) {
-      block_height = (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
-                      kScaleSubPixelBits) +
-                     kSubPixelTaps;
-    }
-    block_buffer =
-        MakeAlignedUniquePtr<uint8_t>(alignment, block_stride * block_height);
-    if (block_buffer == nullptr) {
-      LIBGAV1_DLOG(ERROR, "Can't allocate memory for the reference block.");
-      return;
-    }
     if (bitdepth == 8) {
       BuildConvolveBlock<uint8_t>(
           plane, reference_frame_index, is_scaled, height, ref_start_x,
           ref_last_x, ref_start_y, ref_last_y, step_y, ref_block_start_x,
-          ref_block_end_x, ref_block_start_y, block_buffer.get(), block_stride);
+          ref_block_end_x, ref_block_start_y,
+          block.scratch_buffer->convolve_block_buffer, block_stride);
 #if LIBGAV1_MAX_BITDEPTH >= 10
     } else {
       BuildConvolveBlock<uint16_t>(
           plane, reference_frame_index, is_scaled, height, ref_start_x,
           ref_last_x, ref_start_y, ref_last_y, step_y, ref_block_start_x,
-          ref_block_end_x, ref_block_start_y, block_buffer.get(), block_stride);
+          ref_block_end_x, ref_block_start_y,
+          block.scratch_buffer->convolve_block_buffer, block_stride);
 #endif
     }
-    block_start = block_buffer.get() +
+    block_start = block.scratch_buffer->convolve_block_buffer +
                   (is_scaled ? 0
                              : kConvolveBorderLeftTop * block_stride +
                                    kConvolveBorderLeftTop * pixel_size);
@@ -1256,18 +1165,18 @@
     convolve_func = dsp_.convolve[0][1][1][1];
   }
   convolve_func(block_start, block_stride, horizontal_filter_index,
-                vertical_filter_index, round_bits[1], start_x, start_y, step_x,
+                vertical_filter_index, round_bits, start_x, start_y, step_x,
                 step_y, width, height, output, output_stride);
 }
 
-bool Tile::BlockWarpProcess(const Block& block, const Plane plane,
+void Tile::BlockWarpProcess(const Block& block, const Plane plane,
                             const int index, const int block_start_x,
                             const int block_start_y, const int width,
                             const int height, const ptrdiff_t prediction_stride,
                             GlobalMotion* const warp_params,
-                            const uint8_t* const round_bits,
-                            const bool is_compound, const bool is_inter_intra,
-                            uint8_t* const dest, const ptrdiff_t dest_stride) {
+                            const int round_bits, const bool is_compound,
+                            const bool is_inter_intra, uint8_t* const dest,
+                            const ptrdiff_t dest_stride) {
   assert(width >= 8 && height >= 8);
   const BlockParameters& bp = *block.bp;
   const int reference_frame_index =
@@ -1283,10 +1192,10 @@
   const int source_height =
       reference_frames_[reference_frame_index]->buffer()->displayed_height(
           plane);
-  uint16_t* const prediction = block.sb_buffer->prediction_buffer[index].get();
+  uint16_t* const prediction = block.scratch_buffer->prediction_buffer[index];
   dsp_.warp(source, source_stride, source_width, source_height,
             warp_params->params, subsampling_x_[plane], subsampling_y_[plane],
-            round_bits[1], block_start_x, block_start_y, width, height,
+            round_bits, block_start_x, block_start_y, width, height,
             warp_params->alpha, warp_params->beta, warp_params->gamma,
             warp_params->delta, prediction, prediction_stride);
   if (!is_compound && !is_inter_intra) {
@@ -1301,7 +1210,6 @@
 #endif
     }
   }
-  return true;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/tile/tile.cc b/libgav1/src/tile/tile.cc
index 378410d..4e94e17 100644
--- a/libgav1/src/tile/tile.cc
+++ b/libgav1/src/tile/tile.cc
@@ -200,6 +200,18 @@
     kTransformSize32x32, kTransformSize32x16, kTransformSize32x32,
     kTransformSize32x32};
 
+// This is the same as Max_Tx_Size_Rect array in the spec but with *x64 and 64x*
+// transforms replaced with *x32 and 32x* respectively.
+constexpr TransformSize kUVTransformSize[kMaxBlockSizes] = {
+    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
+    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
+    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
+    kTransformSize16x16, kTransformSize16x32, kTransformSize16x32,
+    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x32, kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x32, kTransformSize32x32, kTransformSize32x32,
+    kTransformSize32x32};
+
 // ith entry of this array is computed as:
 // DivideBy2(TransformSizeToSquareTransformIndex(kTransformSizeSquareMin[i]) +
 //           TransformSizeToSquareTransformIndex(kTransformSizeSquareMax[i]) +
@@ -270,32 +282,6 @@
   return std::min(length, max - start);
 }
 
-// This is the same as Max_Tx_Size_Rect array in the spec but with *x64 and 64*x
-// transforms replaced with *x32 and 32x* respectively.
-constexpr TransformSize kUVTransformSize[kMaxBlockSizes] = {
-    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
-    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
-    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
-    kTransformSize16x16, kTransformSize16x32, kTransformSize16x32,
-    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
-    kTransformSize32x32, kTransformSize32x16, kTransformSize32x32,
-    kTransformSize32x32, kTransformSize32x32, kTransformSize32x32,
-    kTransformSize32x32};
-
-// 5.11.37.
-TransformSize GetTransformSize(bool lossless, BlockSize block_size, Plane plane,
-                               TransformSize tx_size, int subsampling_x,
-                               int subsampling_y) {
-  if (plane == kPlaneY) return tx_size;
-  // For Y Plane, |tx_size| is always kTransformSize4x4. So it is sufficient to
-  // have this special case for |lossless| only for U and V planes.
-  if (lossless) return kTransformSize4x4;
-  const BlockSize plane_size =
-      kPlaneResidualSize[block_size][subsampling_x][subsampling_y];
-  assert(plane_size != kBlockInvalid);
-  return kUVTransformSize[plane_size];
-}
-
 void SetTransformType(const Tile::Block& block, int x4, int y4, int w4, int h4,
                       TransformType tx_type,
                       TransformType transform_types[32][32]) {
@@ -327,6 +313,7 @@
     Array2D<TransformSize>* const inter_transform_sizes,
     const dsp::Dsp* const dsp, ThreadPool* const thread_pool,
     ResidualBufferPool* const residual_buffer_pool,
+    DecoderScratchBufferPool* const decoder_scratch_buffer_pool,
     BlockingCounterWithStatus* const pending_tiles)
     : number_(tile_number),
       data_(data),
@@ -365,6 +352,7 @@
       inter_transform_sizes_(*inter_transform_sizes),
       thread_pool_(thread_pool),
       residual_buffer_pool_(residual_buffer_pool),
+      decoder_scratch_buffer_pool_(decoder_scratch_buffer_pool),
       pending_tiles_(pending_tiles),
       build_bit_mask_when_parsing_(false) {
   row_ = number_ / frame_header.tile_info.tile_columns;
@@ -449,12 +437,19 @@
     if (!ThreadedDecode()) return false;
   } else {
     const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
-    SuperBlockBuffer sb_buffer;
+    std::unique_ptr<DecoderScratchBuffer> scratch_buffer =
+        decoder_scratch_buffer_pool_->Get();
+    if (scratch_buffer == nullptr) {
+      pending_tiles_->Decrement(false);
+      LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
+      return false;
+    }
     for (int row4x4 = row4x4_start_; row4x4 < row4x4_end_;
          row4x4 += block_width4x4) {
       for (int column4x4 = column4x4_start_; column4x4 < column4x4_end_;
            column4x4 += block_width4x4) {
-        if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4, &sb_buffer,
+        if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4,
+                               scratch_buffer.get(),
                                kProcessingModeParseAndDecode)) {
           pending_tiles_->Decrement(false);
           LIBGAV1_DLOG(ERROR, "Error decoding super block row: %d column: %d",
@@ -463,6 +458,7 @@
         }
       }
     }
+    decoder_scratch_buffer_pool_->Release(std::move(scratch_buffer));
   }
   if (frame_header_.enable_frame_end_update_cdf &&
       number_ == frame_header_.tile_info.context_update_id) {
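
Between hunks: the per-call SuperBlockBuffer is gone; a DecoderScratchBuffer is checked out of a shared pool for the whole superblock loop and returned afterwards, so allocations are reused across tiles and threads. A generic sketch of the Get()/Release() shape used here (this pool is a stand-in, not libgav1's implementation):

#include <memory>
#include <mutex>
#include <new>
#include <utility>
#include <vector>

struct ScratchBuffer { /* per-superblock temporaries */ };

class ScratchBufferPool {
 public:
  // Hands out a pooled buffer, or allocates one if the pool is empty.
  std::unique_ptr<ScratchBuffer> Get() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (buffers_.empty()) {
      return std::unique_ptr<ScratchBuffer>(new (std::nothrow) ScratchBuffer);
    }
    std::unique_ptr<ScratchBuffer> buffer = std::move(buffers_.back());
    buffers_.pop_back();
    return buffer;
  }
  // Returns a buffer to the pool for reuse.
  void Release(std::unique_ptr<ScratchBuffer> buffer) {
    std::lock_guard<std::mutex> lock(mutex_);
    buffers_.push_back(std::move(buffer));
  }

 private:
  std::mutex mutex_;
  std::vector<std::unique_ptr<ScratchBuffer>> buffers_;
};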
@@ -489,14 +485,20 @@
   const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
 
   // Begin parsing.
-  SuperBlockBuffer sb_buffer;
+  std::unique_ptr<DecoderScratchBuffer> scratch_buffer =
+      decoder_scratch_buffer_pool_->Get();
+  if (scratch_buffer == nullptr) {
+    pending_tiles_->Decrement(false);
+    LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
+    return false;
+  }
   for (int row4x4 = row4x4_start_, row_index = 0; row4x4 < row4x4_end_;
        row4x4 += block_width4x4, ++row_index) {
     for (int column4x4 = column4x4_start_, column_index = 0;
          column4x4 < column4x4_end_;
          column4x4 += block_width4x4, ++column_index) {
-      if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4, &sb_buffer,
-                             kProcessingModeParseOnly)) {
+      if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4,
+                             scratch_buffer.get(), kProcessingModeParseOnly)) {
         std::lock_guard<std::mutex> lock(threading_.mutex);
         threading_.abort = true;
         break;
@@ -519,6 +521,7 @@
     std::lock_guard<std::mutex> lock(threading_.mutex);
     if (threading_.abort) break;
   }
+  decoder_scratch_buffer_pool_->Release(std::move(scratch_buffer));
 
   // We are done parsing. We can return here since the calling thread will make
   // sure that it waits for all the superblocks to be decoded.
@@ -574,11 +577,16 @@
 
 void Tile::DecodeSuperBlock(int row_index, int column_index,
                             int block_width4x4) {
-  SuperBlockBuffer sb_buffer;
   const int row4x4 = row4x4_start_ + (row_index * block_width4x4);
   const int column4x4 = column4x4_start_ + (column_index * block_width4x4);
-  const bool ok = ProcessSuperBlock(row4x4, column4x4, block_width4x4,
-                                    &sb_buffer, kProcessingModeDecodeOnly);
+  std::unique_ptr<DecoderScratchBuffer> scratch_buffer =
+      decoder_scratch_buffer_pool_->Get();
+  bool ok = scratch_buffer != nullptr;
+  if (ok) {
+    ok = ProcessSuperBlock(row4x4, column4x4, block_width4x4,
+                           scratch_buffer.get(), kProcessingModeDecodeOnly);
+    decoder_scratch_buffer_pool_->Release(std::move(scratch_buffer));
+  }
   std::unique_lock<std::mutex> lock(threading_.mutex);
   if (ok) {
     threading_.sb_state[row_index][column_index] = kSuperBlockStateDecoded;
@@ -613,12 +621,15 @@
   } else {
     threading_.abort = true;
   }
-  if (--threading_.pending_jobs == 0) {
-    // The lock needs to be unlocked here in this case because the Tile object
-    // could go out of scope as soon as |pending_tiles_->Decrement()| is called.
-    lock.unlock();
+  // Finish using |threading_| before |pending_tiles_->Decrement()| because the
+  // Tile object could go out of scope as soon as |pending_tiles_->Decrement()|
+  // is called.
+  const bool no_pending_jobs = (--threading_.pending_jobs == 0);
+  const bool job_succeeded = !threading_.abort;
+  lock.unlock();
+  if (no_pending_jobs) {
     // We are done parsing and decoding this tile.
-    pending_tiles_->Decrement(!threading_.abort);
+    pending_tiles_->Decrement(job_succeeded);
   }
 }
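
Between hunks: the reordering above snapshots everything guarded by |threading_.mutex| before unlocking, and only then signals completion, because pending_tiles_->Decrement() can allow the Tile (and its mutex) to be destroyed. The pattern in miniature:

#include <mutex>

struct Threading {
  std::mutex mutex;
  int pending_jobs;
  bool abort;
};

// Read the shared state under the lock, unlock, then signal; nothing guarded
// by |threading.mutex| may be touched after the final decrement.
template <typename Decrement>
void FinishJob(Threading& threading, const Decrement& decrement) {
  std::unique_lock<std::mutex> lock(threading.mutex);
  const bool no_pending_jobs = (--threading.pending_jobs == 0);
  const bool job_succeeded = !threading.abort;
  lock.unlock();
  if (no_pending_jobs) decrement(job_succeeded);
}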
 
@@ -635,12 +646,9 @@
 
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
-  const BlockSize plane_block_size =
-      kPlaneResidualSize[block.size][subsampling_x_[plane]]
-                        [subsampling_y_[plane]];
-  assert(plane_block_size != kBlockInvalid);
-  const int block_width = kBlockWidthPixels[plane_block_size];
-  const int block_height = kBlockHeightPixels[plane_block_size];
+  const BlockSize plane_size = block.residual_size[GetPlaneType(plane)];
+  const int block_width = kBlockWidthPixels[plane_size];
+  const int block_height = kBlockHeightPixels[plane_size];
 
   int top = 0;
   int left = 0;
@@ -775,13 +783,14 @@
 }
 
 // Section 8.3.2 in the spec, under coeff_base.
-int Tile::GetCoeffBaseContext2D(TransformSize tx_size,
+int Tile::GetCoeffBaseContext2D(const int32_t* const quantized_buffer,
+                                TransformSize tx_size,
                                 int adjusted_tx_width_log2, uint16_t pos) {
   if (pos == 0) return 0;
   const int tx_width = 1 << adjusted_tx_width_log2;
   const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized =
-      &quantized_[PaddedIndex(pos, adjusted_tx_width_log2)];
+  const int32_t* const quantized =
+      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
   const int context = std::min(
       4, DivideBy2(1 + (std::min(quantized[1], 3) +                    // {0, 1}
                         std::min(quantized[padded_tx_width], 3) +      // {1, 0}
@@ -796,13 +805,14 @@
 }
 
 // Section 8.3.2 in the spec, under coeff_base.
-int Tile::GetCoeffBaseContextHorizontal(TransformSize /*tx_size*/,
+int Tile::GetCoeffBaseContextHorizontal(const int32_t* const quantized_buffer,
+                                        TransformSize /*tx_size*/,
                                         int adjusted_tx_width_log2,
                                         uint16_t pos) {
   const int tx_width = 1 << adjusted_tx_width_log2;
   const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized =
-      &quantized_[PaddedIndex(pos, adjusted_tx_width_log2)];
+  const int32_t* const quantized =
+      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
   const int context = std::min(
       4, DivideBy2(1 + (std::min(quantized[1], 3) +                // {0, 1}
                         std::min(quantized[padded_tx_width], 3) +  // {1, 0}
@@ -814,13 +824,14 @@
 }
 
 // Section 8.3.2 in the spec, under coeff_base.
-int Tile::GetCoeffBaseContextVertical(TransformSize /*tx_size*/,
+int Tile::GetCoeffBaseContextVertical(const int32_t* const quantized_buffer,
+                                      TransformSize /*tx_size*/,
                                       int adjusted_tx_width_log2,
                                       uint16_t pos) {
   const int tx_width = 1 << adjusted_tx_width_log2;
   const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized =
-      &quantized_[PaddedIndex(pos, adjusted_tx_width_log2)];
+  const int32_t* const quantized =
+      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
   const int context = std::min(
       4, DivideBy2(1 + (std::min(quantized[1], 3) +                // {0, 1}
                         std::min(quantized[padded_tx_width], 3) +  // {1, 0}
@@ -835,11 +846,12 @@
 }
 
 // Section 8.3.2 in the spec, under coeff_br.
-int Tile::GetCoeffBaseRangeContext2D(int adjusted_tx_width_log2, int pos) {
+int Tile::GetCoeffBaseRangeContext2D(const int32_t* const quantized_buffer,
+                                     int adjusted_tx_width_log2, int pos) {
   const uint8_t tx_width = 1 << adjusted_tx_width_log2;
   const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized =
-      &quantized_[PaddedIndex(pos, adjusted_tx_width_log2)];
+  const int32_t* const quantized =
+      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
   const int context = std::min(
       6, DivideBy2(
              1 +
@@ -856,12 +868,13 @@
 }
 
 // Section 8.3.2 in the spec, under coeff_br.
-int Tile::GetCoeffBaseRangeContextHorizontal(int adjusted_tx_width_log2,
-                                             int pos) {
+int Tile::GetCoeffBaseRangeContextHorizontal(
+    const int32_t* const quantized_buffer, int adjusted_tx_width_log2,
+    int pos) {
   const uint8_t tx_width = 1 << adjusted_tx_width_log2;
   const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized =
-      &quantized_[PaddedIndex(pos, adjusted_tx_width_log2)];
+  const int32_t* const quantized =
+      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
   const int context = std::min(
       6, DivideBy2(
              1 +
@@ -877,12 +890,13 @@
 }
 
 // Section 8.3.2 in the spec, under coeff_br.
-int Tile::GetCoeffBaseRangeContextVertical(int adjusted_tx_width_log2,
-                                           int pos) {
+int Tile::GetCoeffBaseRangeContextVertical(
+    const int32_t* const quantized_buffer, int adjusted_tx_width_log2,
+    int pos) {
   const uint8_t tx_width = 1 << adjusted_tx_width_log2;
   const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized =
-      &quantized_[PaddedIndex(pos, adjusted_tx_width_log2)];
+  const int32_t* const quantized =
+      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
   const int context = std::min(
       6, DivideBy2(
              1 +
@@ -994,18 +1008,20 @@
 
 template <bool is_dc_coefficient>
 bool Tile::ReadSignAndApplyDequantization(
-    const Block& block, const uint16_t* const scan, int i,
-    int adjusted_tx_width_log2, int tx_width, int q_value,
-    const uint8_t* const quantizer_matrix, int shift, int min_value,
-    int max_value, uint16_t* const dc_sign_cdf, int8_t* const dc_category,
-    int* const coefficient_level) {
+    const Block& block, int32_t* const quantized_buffer,
+    const uint16_t* const scan, int i, int adjusted_tx_width_log2, int tx_width,
+    int q_value, const uint8_t* const quantizer_matrix, int shift,
+    int min_value, int max_value, uint16_t* const dc_sign_cdf,
+    int8_t* const dc_category, int* const coefficient_level) {
   int pos = is_dc_coefficient ? 0 : scan[i];
   const int pos_index =
       is_dc_coefficient ? 0 : PaddedIndex(pos, adjusted_tx_width_log2);
-  const bool sign = quantized_[pos_index] != 0 &&
-                    (is_dc_coefficient ? reader_.ReadSymbol(dc_sign_cdf)
-                                       : static_cast<bool>(reader_.ReadBit()));
-  if (quantized_[pos_index] >
+  // If quantized_buffer[pos_index] is zero, then the rest of the function has
+  // no effect.
+  if (quantized_buffer[pos_index] == 0) return true;
+  const bool sign = is_dc_coefficient ? reader_.ReadSymbol(dc_sign_cdf)
+                                      : static_cast<bool>(reader_.ReadBit());
+  if (quantized_buffer[pos_index] >
       kNumQuantizerBaseLevels + kQuantizerCoefficientBaseRange) {
     int length = 0;
     bool golomb_length_bit = false;
@@ -1021,13 +1037,13 @@
     for (int i = length - 2; i >= 0; --i) {
       x = (x << 1) | reader_.ReadBit();
     }
-    quantized_[pos_index] += x - 1;
+    quantized_buffer[pos_index] += x - 1;
   }
-  if (is_dc_coefficient && quantized_[0] > 0) {
+  if (is_dc_coefficient && quantized_buffer[0] > 0) {
     *dc_category = sign ? -1 : 1;
   }
-  quantized_[pos_index] &= 0xfffff;
-  *coefficient_level += quantized_[pos_index];
+  quantized_buffer[pos_index] &= 0xfffff;
+  *coefficient_level += quantized_buffer[pos_index];
   // Apply dequantization. Step 1 of section 7.12.3 in the spec.
   int q = q_value;
   if (quantizer_matrix != nullptr) {
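
When a level exceeds kNumQuantizerBaseLevels + kQuantizerCoefficientBaseRange, the branch above extends it with an Exp-Golomb-style remainder: a unary length prefix, then that many explicit suffix bits under an implicit leading one. A minimal sketch, assuming read_bit() yields the next bitstream bit (the real decoder also caps the prefix length):

#include <functional>

// Sketch of the remainder decode, under the assumptions stated above.
int ReadGolombRemainder(const std::function<int()>& read_bit) {
  // Unary prefix: count zeros until the first one bit.
  int length = 0;
  while (read_bit() == 0) ++length;
  // Suffix: |length| explicit bits appended to an implicit leading one.
  int x = 1;
  for (int i = 0; i < length; ++i) x = (x << 1) | read_bit();
  return x - 1;  // Added to the coefficient, as in the hunk above.
}
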
@@ -1036,7 +1052,7 @@
   // The intermediate multiplication can exceed 32 bits, so it has to be
   // performed by promoting one of the values to int64_t.
   int32_t dequantized_value =
-      (static_cast<int64_t>(q) * quantized_[pos_index]) & 0xffffff;
+      (static_cast<int64_t>(q) * quantized_buffer[pos_index]) & 0xffffff;
   dequantized_value >>= shift;
   if (sign) {
     dequantized_value = -dequantized_value;
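
The widening noted in the comment above is the whole trick: both operands are 32-bit, so the product is formed in 64 bits and then masked back down. A sketch of just that step (Step 1 of section 7.12.3 in the spec):

#include <cstdint>

// Sketch of the dequantization arithmetic above: widen one operand to
// int64_t so the multiply cannot overflow, mask to 24 bits, then shift.
int32_t DequantizeSketch(int32_t q, int32_t level, int shift) {
  const int32_t value =
      static_cast<int32_t>((static_cast<int64_t>(q) * level) & 0xffffff);
  return value >> shift;  // The caller negates the result if sign is set.
}
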
@@ -1055,13 +1071,11 @@
     pos = MultiplyBy64(row_index) + column_index;
   }
   if (sequence_header_.color_config.bitdepth == 8) {
-    auto* const residual_buffer =
-        reinterpret_cast<int16_t*>(block.sb_buffer->residual);
+    auto* const residual_buffer = reinterpret_cast<int16_t*>(*block.residual);
     residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
 #if LIBGAV1_MAX_BITDEPTH >= 10
   } else {
-    auto* const residual_buffer =
-        reinterpret_cast<int32_t*>(block.sb_buffer->residual);
+    auto* const residual_buffer = reinterpret_cast<int32_t*>(*block.residual);
     residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
 #endif
   }
@@ -1106,18 +1120,19 @@
   }
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
-  memset(block.sb_buffer->residual, 0, tx_width * tx_height * residual_size_);
+  memset(*block.residual, 0, tx_width * tx_height * residual_size_);
   const int clamped_tx_width = std::min(tx_width, 32);
   const int clamped_tx_height = std::min(tx_height, 32);
   const int padded_tx_width =
       clamped_tx_width + kQuantizedCoefficientBufferPadding;
   const int padded_tx_height =
       clamped_tx_height + kQuantizedCoefficientBufferPadding;
-  // Only the first |padded_tx_width| * |padded_tx_height| values of
-  // |quantized_| will be used by this function. So we simply need to zero out
-  // those values before it is being used (instead of zeroing the entire array).
-  memset(quantized_, 0,
-         padded_tx_width * padded_tx_height * sizeof(quantized_[0]));
+  int32_t* const quantized = block.scratch_buffer->quantized_buffer;
+  // Only the first |padded_tx_width| * |padded_tx_height| values of |quantized|
+  // will be used by this function and the functions it is passed to. So we
+  // only need to zero out those values before the buffer is used.
+  memset(quantized, 0,
+         padded_tx_width * padded_tx_height * sizeof(quantized[0]));
   if (plane == kPlaneY) {
     ReadTransformType(block, x4, y4, tx_size);
   }
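
The partial clear above is a measurable saving: for a 16x16 transform with 4 columns/rows of padding, only (16 + 4) * (16 + 4) = 400 entries are zeroed instead of the whole scratch array. A sketch of the clear, assuming the buffer covers at least a 32x32 block plus padding:

#include <algorithm>
#include <cstdint>
#include <cstring>

// Zeroes only the prefix of the scratch buffer that the coefficient loops
// will read; |padding| stands in for kQuantizedCoefficientBufferPadding.
void ClearQuantizedRegion(int32_t* quantized, int tx_width, int tx_height,
                          int padding) {
  const int padded_tx_width = std::min(tx_width, 32) + padding;
  const int padded_tx_height = std::min(tx_height, 32) + padding;
  memset(quantized, 0,
         padded_tx_width * padded_tx_height * sizeof(quantized[0]));
}
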
@@ -1125,7 +1140,8 @@
   *tx_type = ComputeTransformType(block, plane, tx_size, x4, y4);
   const int eob_multi_size = kEobMultiSizeLookup[tx_size];
   const PlaneType plane_type = GetPlaneType(plane);
-  context = static_cast<int>(GetTransformClass(*tx_type) != kTransformClass2D);
+  const TransformClass tx_class = GetTransformClass(*tx_type);
+  context = static_cast<int>(tx_class != kTransformClass2D);
   uint16_t* cdf;
   switch (eob_multi_size) {
     case 0:
@@ -1168,23 +1184,22 @@
       }
     }
   }
-  const TransformClass tx_class = GetTransformClass(*tx_type);
   const uint16_t* scan = kScan[tx_class][tx_size];
   const TransformSize adjusted_tx_size = kAdjustedTransformSize[tx_size];
   const int adjusted_tx_width_log2 = kTransformWidthLog2[adjusted_tx_size];
   // Lookup used to call the right variant of GetCoeffBaseContext*() based on
   // the transform class.
   static constexpr int (Tile::*kGetCoeffBaseContextFunc[])(
-      TransformSize, int, uint16_t) = {&Tile::GetCoeffBaseContext2D,
-                                       &Tile::GetCoeffBaseContextHorizontal,
-                                       &Tile::GetCoeffBaseContextVertical};
+      const int32_t*, TransformSize, int, uint16_t) = {
+      &Tile::GetCoeffBaseContext2D, &Tile::GetCoeffBaseContextHorizontal,
+      &Tile::GetCoeffBaseContextVertical};
   auto get_coeff_base_context_func = kGetCoeffBaseContextFunc[tx_class];
   // Lookup used to call the right variant of GetCoeffBaseRangeContext*() based
   // on the transform class.
-  static constexpr int (Tile::*kGetCoeffBaseRangeContextFunc[])(int, int) = {
-      &Tile::GetCoeffBaseRangeContext2D,
-      &Tile::GetCoeffBaseRangeContextHorizontal,
-      &Tile::GetCoeffBaseRangeContextVertical};
+  static constexpr int (Tile::*kGetCoeffBaseRangeContextFunc[])(
+      const int32_t*, int, int) = {&Tile::GetCoeffBaseRangeContext2D,
+                                   &Tile::GetCoeffBaseRangeContextHorizontal,
+                                   &Tile::GetCoeffBaseRangeContextVertical};
   auto get_coeff_base_range_context_func =
       kGetCoeffBaseRangeContextFunc[tx_class];
   const int clamped_tx_size_context = std::min(tx_size_context, 3);
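
The two tables above use arrays of pointer-to-member functions indexed by the transform class, so the per-coefficient loops below pick the right context function once instead of re-branching on every coefficient. A reduced, self-contained sketch of the pattern (all names hypothetical):

// Dispatch through a constexpr table of pointer-to-member functions.
struct ContextDispatcher {
  int Context2D(int pos) { return pos; }
  int ContextHorizontal(int pos) { return pos + 1; }
  int ContextVertical(int pos) { return pos + 2; }

  int GetContext(int tx_class, int pos) {
    static constexpr int (ContextDispatcher::*kFuncs[])(int) = {
        &ContextDispatcher::Context2D, &ContextDispatcher::ContextHorizontal,
        &ContextDispatcher::ContextVertical};
    return (this->*kFuncs[tx_class])(pos);
  }
};
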
@@ -1200,15 +1215,15 @@
     if (level > kNumQuantizerBaseLevels) {
       level += ReadCoeffBaseRange(clamped_tx_size_context,
                                   (this->*get_coeff_base_range_context_func)(
-                                      adjusted_tx_width_log2, pos),
+                                      quantized, adjusted_tx_width_log2, pos),
                                   plane_type);
     }
-    quantized_[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
+    quantized[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
   }
   // Read all the other coefficients.
   for (int i = eob - 2; i >= 0; --i) {
     const uint16_t pos = scan[i];
-    context = (this->*get_coeff_base_context_func)(tx_size,
+    context = (this->*get_coeff_base_context_func)(quantized, tx_size,
                                                    adjusted_tx_width_log2, pos);
     int level = reader_.ReadSymbol<kCoeffBaseSymbolCount>(
         symbol_decoder_context_
@@ -1216,10 +1231,10 @@
     if (level > kNumQuantizerBaseLevels) {
       level += ReadCoeffBaseRange(clamped_tx_size_context,
                                   (this->*get_coeff_base_range_context_func)(
-                                      adjusted_tx_width_log2, pos),
+                                      quantized, adjusted_tx_width_log2, pos),
                                   plane_type);
     }
-    quantized_[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
+    quantized[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
   }
   const int min_value = -(1 << (7 + sequence_header_.color_config.bitdepth));
   const int max_value = (1 << (7 + sequence_header_.color_config.bitdepth)) - 1;
@@ -1239,29 +1254,29 @@
   int coefficient_level = 0;
   int8_t dc_category = 0;
   uint16_t* const dc_sign_cdf =
-      (quantized_[0] != 0)
+      (quantized[0] != 0)
           ? symbol_decoder_context_.dc_sign_cdf[plane_type][GetDcSignContext(
                 x4, y4, w4, h4, plane)]
           : nullptr;
   assert(scan[0] == 0);
   if (!ReadSignAndApplyDequantization</*is_dc_coefficient=*/true>(
-          block, scan, 0, adjusted_tx_width_log2, tx_width, dc_q_value,
-          quantizer_matrix, shift, min_value, max_value, dc_sign_cdf,
-          &dc_category, &coefficient_level)) {
+          block, quantized, scan, 0, adjusted_tx_width_log2, tx_width,
+          dc_q_value, quantizer_matrix, shift, min_value, max_value,
+          dc_sign_cdf, &dc_category, &coefficient_level)) {
     return -1;
   }
   for (int i = 1; i < eob; ++i) {
     if (!ReadSignAndApplyDequantization</*is_dc_coefficient=*/false>(
-            block, scan, i, adjusted_tx_width_log2, tx_width, ac_q_value,
-            quantizer_matrix, shift, min_value, max_value, nullptr, nullptr,
-            &coefficient_level)) {
+            block, quantized, scan, i, adjusted_tx_width_log2, tx_width,
+            ac_q_value, quantizer_matrix, shift, min_value, max_value, nullptr,
+            nullptr, &coefficient_level)) {
       return -1;
     }
   }
   SetEntropyContexts(x4, y4, w4, h4, plane, std::min(4, coefficient_level),
                      dc_category);
   if (split_parse_and_decode_) {
-    block.sb_buffer->residual += tx_width * tx_height * residual_size_;
+    *block.residual += tx_width * tx_height * residual_size_;
   }
   return eob;
 }
@@ -1317,15 +1332,15 @@
       if (sequence_header_.color_config.bitdepth == 8) {
         IntraPrediction<uint8_t>(
             block, plane, start_x, start_y, has_left, has_top,
-            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
             mode, tx_size);
 #if LIBGAV1_MAX_BITDEPTH >= 10
       } else {
         IntraPrediction<uint16_t>(
             block, plane, start_x, start_y, has_left, has_top,
-            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
             mode, tx_size);
 #endif
       }
@@ -1346,10 +1361,12 @@
           start_x + MultiplyBy4(step_x);
       block.bp->prediction_parameters->max_luma_height =
           start_y + MultiplyBy4(step_y);
-      block.sb_buffer->cfl_luma_buffer_valid = false;
+      block.scratch_buffer->cfl_luma_buffer_valid = false;
     }
   }
   if (!bp.skip) {
+    const int sb_row_index = SuperBlockRowIndex(block.row4x4);
+    const int sb_column_index = SuperBlockColumnIndex(block.column4x4);
     switch (mode) {
       case kProcessingModeParseAndDecode: {
         TransformType tx_type;
@@ -1365,29 +1382,31 @@
         const int16_t non_zero_coeff_count = ReadTransformCoefficients(
             block, plane, start_x, start_y, tx_size, &tx_type);
         if (non_zero_coeff_count < 0) return false;
-        block.sb_buffer->transform_parameters->Push(non_zero_coeff_count,
-                                                    tx_type);
+        residual_buffer_threaded_[sb_row_index][sb_column_index]
+            ->transform_parameters()
+            ->Push(non_zero_coeff_count, tx_type);
         break;
       }
       case kProcessingModeDecodeOnly: {
-        ReconstructBlock(
-            block, plane, start_x, start_y, tx_size,
-            block.sb_buffer->transform_parameters->Type(),
-            block.sb_buffer->transform_parameters->NonZeroCoeffCount());
-        block.sb_buffer->transform_parameters->Pop();
+        TransformParameterQueue& tx_params =
+            *residual_buffer_threaded_[sb_row_index][sb_column_index]
+                 ->transform_parameters();
+        ReconstructBlock(block, plane, start_x, start_y, tx_size,
+                         tx_params.Type(), tx_params.NonZeroCoeffCount());
+        tx_params.Pop();
         break;
       }
     }
   }
   if (do_decode) {
     bool* block_decoded =
-        &block.sb_buffer
+        &block.scratch_buffer
              ->block_decoded[plane][(sub_block_row4x4 >> subsampling_y) + 1]
                             [(sub_block_column4x4 >> subsampling_x) + 1];
     for (int i = 0; i < step_y; ++i) {
       static_assert(sizeof(bool) == 1, "");
       memset(block_decoded, 1, step_x);
-      block_decoded += kBlockDecodedStride;
+      block_decoded += DecoderScratchBuffer::kBlockDecodedStride;
     }
   }
   return true;
@@ -1460,8 +1479,8 @@
   if (sequence_header_.color_config.bitdepth == 8) {
     Reconstruct(dsp_, tx_type, tx_size,
                 frame_header_.segmentation.lossless[block.bp->segment_id],
-                reinterpret_cast<int16_t*>(block.sb_buffer->residual), start_x,
-                start_y, &buffer_[plane], non_zero_coeff_count);
+                reinterpret_cast<int16_t*>(*block.residual), start_x, start_y,
+                &buffer_[plane], non_zero_coeff_count);
 #if LIBGAV1_MAX_BITDEPTH >= 10
   } else {
     Array2DView<uint16_t> buffer(
@@ -1469,12 +1488,12 @@
         reinterpret_cast<uint16_t*>(&buffer_[plane][0][0]));
     Reconstruct(dsp_, tx_type, tx_size,
                 frame_header_.segmentation.lossless[block.bp->segment_id],
-                reinterpret_cast<int32_t*>(block.sb_buffer->residual), start_x,
-                start_y, &buffer, non_zero_coeff_count);
+                reinterpret_cast<int32_t*>(*block.residual), start_x, start_y,
+                &buffer, non_zero_coeff_count);
 #endif
   }
   if (split_parse_and_decode_) {
-    block.sb_buffer->residual +=
+    *block.residual +=
         kTransformWidth[tx_size] * kTransformHeight[tx_size] * residual_size_;
   }
 }
@@ -1485,17 +1504,17 @@
   const BlockSize size_chunk4x4 =
       (width_chunks > 1 || height_chunks > 1) ? kBlock64x64 : block.size;
   const BlockParameters& bp = *block.bp;
-  const TransformSize y_tx_size =
-      GetTransformSize(frame_header_.segmentation.lossless[bp.segment_id],
-                       block.size, kPlaneY, bp.transform_size, 0, 0);
   for (int chunk_y = 0; chunk_y < height_chunks; ++chunk_y) {
     for (int chunk_x = 0; chunk_x < width_chunks; ++chunk_x) {
       for (int plane = 0; plane < (block.HasChroma() ? PlaneCount() : 1);
            ++plane) {
         const int subsampling_x = subsampling_x_[plane];
         const int subsampling_y = subsampling_y_[plane];
+        // For the Y plane, when lossless is true, |bp.transform_size| is
+        // always kTransformSize4x4. So we can simply use it here as the Y
+        // plane's transform size (part of Section 5.11.37 in the spec).
         const TransformSize tx_size =
-            (plane == kPlaneY) ? y_tx_size : bp.uv_transform_size;
+            (plane == kPlaneY) ? bp.transform_size : bp.uv_transform_size;
         const BlockSize plane_size =
             kPlaneResidualSize[size_chunk4x4][subsampling_x][subsampling_y];
         assert(plane_size != kBlockInvalid);
@@ -1672,7 +1691,7 @@
   }
 }
 
-bool Tile::ComputePrediction(const Block& block) {
+void Tile::ComputePrediction(const Block& block) {
   const int mask =
       (1 << (4 + static_cast<int>(sequence_header_.use_128x128_superblock))) -
       1;
@@ -1689,8 +1708,7 @@
     const int8_t subsampling_x = subsampling_x_[plane];
     const int8_t subsampling_y = subsampling_y_[plane];
     const BlockSize plane_size =
-        kPlaneResidualSize[block.size][subsampling_x][subsampling_y];
-    assert(plane_size != kBlockInvalid);
+        block.residual_size[GetPlaneType(static_cast<Plane>(plane))];
     const int block_width4x4 = kNum4x4BlocksWide[plane_size];
     const int block_height4x4 = kNum4x4BlocksHigh[plane_size];
     const int block_width = MultiplyBy4(block_width4x4);
@@ -1715,8 +1733,8 @@
       if (sequence_header_.color_config.bitdepth == 8) {
         IntraPrediction<uint8_t>(
             block, static_cast<Plane>(plane), base_x, base_y, has_left, has_top,
-            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
             kInterIntraToIntraMode[block.bp->prediction_parameters
                                        ->inter_intra_mode],
             tx_size);
@@ -1724,8 +1742,8 @@
       } else {
         IntraPrediction<uint16_t>(
             block, static_cast<Plane>(plane), base_x, base_y, has_left, has_top,
-            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
             kInterIntraToIntraMode[block.bp->prediction_parameters
                                        ->inter_intra_mode],
             tx_size);
@@ -1761,17 +1779,14 @@
       }
       for (int r = 0, y = 0; y < block_height; y += prediction_height, ++r) {
         for (int c = 0, x = 0; x < block_width; x += prediction_width, ++c) {
-          if (!InterPrediction(block, static_cast<Plane>(plane), base_x + x,
-                               base_y + y, prediction_width, prediction_height,
-                               candidate_row + r, candidate_column + c,
-                               &is_local_valid, &local_warp_params)) {
-            return false;
-          }
+          InterPrediction(block, static_cast<Plane>(plane), base_x + x,
+                          base_y + y, prediction_width, prediction_height,
+                          candidate_row + r, candidate_column + c,
+                          &is_local_valid, &local_warp_params);
         }
       }
     }
   }
-  return true;
 }
 
 void Tile::PopulateDeblockFilterLevel(const Block& block) {
@@ -1792,7 +1807,8 @@
 
 bool Tile::ProcessBlock(int row4x4, int column4x4, BlockSize block_size,
                         ParameterTree* const tree,
-                        SuperBlockBuffer* const sb_buffer) {
+                        DecoderScratchBuffer* const scratch_buffer,
+                        ResidualPtr* residual) {
   // Do not process the block if the starting point is beyond the visible frame.
   // This is equivalent to the has_row/has_column check in the
   // decode_partition() section of the spec when partition equals
@@ -1801,7 +1817,7 @@
       column4x4 >= frame_header_.columns4x4) {
     return true;
   }
-  Block block(*this, row4x4, column4x4, block_size, sb_buffer,
+  Block block(*this, row4x4, column4x4, block_size, scratch_buffer, residual,
               tree->parameters());
   block.bp->size = block_size;
   block_parameters_holder_.FillCache(row4x4, column4x4, block_size,
@@ -1816,19 +1832,19 @@
   if (!ReadPaletteTokens(block)) return false;
   DecodeTransformSize(block);
   BlockParameters& bp = *block.bp;
-  bp.uv_transform_size = GetTransformSize(
-      frame_header_.segmentation.lossless[bp.segment_id], block.size, kPlaneU,
-      bp.transform_size, subsampling_x_[kPlaneU], subsampling_y_[kPlaneU]);
+  // Part of Section 5.11.37 in the spec (implemented as a simple lookup).
+  bp.uv_transform_size =
+      frame_header_.segmentation.lossless[bp.segment_id]
+          ? kTransformSize4x4
+          : kUVTransformSize[block.residual_size[kPlaneTypeUV]];
   if (bp.skip) ResetEntropyContext(block);
   const int block_width4x4 = kNum4x4BlocksWide[block_size];
   const int block_height4x4 = kNum4x4BlocksHigh[block_size];
   if (split_parse_and_decode_) {
     if (!Residual(block, kProcessingModeParseOnly)) return false;
   } else {
-    if (!ComputePrediction(block) ||
-        !Residual(block, kProcessingModeParseAndDecode)) {
-      return false;
-    }
+    ComputePrediction(block);
+    if (!Residual(block, kProcessingModeParseAndDecode)) return false;
   }
   // If frame_header_.segmentation.enabled is false, bp.segment_id is 0 for all
   // blocks. We don't need to call save bp.segment_id in the current frame
@@ -1858,7 +1874,8 @@
 }
 
 bool Tile::DecodeBlock(ParameterTree* const tree,
-                       SuperBlockBuffer* const sb_buffer) {
+                       DecoderScratchBuffer* const scratch_buffer,
+                       ResidualPtr* residual) {
   const int row4x4 = tree->row4x4();
   const int column4x4 = tree->column4x4();
   if (row4x4 >= frame_header_.rows4x4 ||
@@ -1866,12 +1883,10 @@
     return true;
   }
   const BlockSize block_size = tree->block_size();
-  Block block(*this, row4x4, column4x4, block_size, sb_buffer,
+  Block block(*this, row4x4, column4x4, block_size, scratch_buffer, residual,
               tree->parameters());
-  if (!ComputePrediction(block) ||
-      !Residual(block, kProcessingModeDecodeOnly)) {
-    return false;
-  }
+  ComputePrediction(block);
+  if (!Residual(block, kProcessingModeDecodeOnly)) return false;
   if (!build_bit_mask_when_parsing_) {
     BuildBitMask(row4x4, column4x4, block_size);
   }
@@ -1882,7 +1897,8 @@
 
 bool Tile::ProcessPartition(int row4x4_start, int column4x4_start,
                             ParameterTree* const root,
-                            SuperBlockBuffer* const sb_buffer) {
+                            DecoderScratchBuffer* const scratch_buffer,
+                            ResidualPtr* residual) {
   Stack<ParameterTree*, kDfsStackSize> stack;
 
   // Set up the first iteration.
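
ProcessPartition() walks the partition tree iteratively with this fixed-capacity stack rather than recursing, bounding stack growth at kDfsStackSize entries. A reduced sketch of the traversal shape (Node, kMaxChildren and the capacity are hypothetical stand-ins):

#include <cassert>

struct Node {
  static constexpr int kMaxChildren = 4;
  Node* children[kMaxChildren] = {};
  bool IsSplit() const { return children[0] != nullptr; }
};

template <typename Visit>
void WalkPartitionTree(Node* root, Visit&& visit) {
  constexpr int kDfsStackSize = 32;
  Node* stack[kDfsStackSize];
  int top = -1;
  stack[++top] = root;
  while (top >= 0) {
    Node* node = stack[top--];
    if (node->IsSplit()) {
      // Push children in reverse so they are visited left to right.
      for (int i = Node::kMaxChildren - 1; i >= 0; --i) {
        if (node->children[i] != nullptr) {
          assert(top + 1 < kDfsStackSize);
          stack[++top] = node->children[i];
        }
      }
      continue;
    }
    visit(node);  // Leaf: corresponds to ProcessBlock()/DecodeBlock().
  }
}
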
@@ -1942,7 +1958,8 @@
     }
     switch (partition) {
       case kPartitionNone:
-        if (!ProcessBlock(row4x4, column4x4, sub_size, node, sb_buffer)) {
+        if (!ProcessBlock(row4x4, column4x4, sub_size, node, scratch_buffer,
+                          residual)) {
           return false;
         }
         break;
@@ -1969,7 +1986,8 @@
           // null.
           if (child == nullptr) break;
           if (!ProcessBlock(child->row4x4(), child->column4x4(),
-                            child->block_size(), child, sb_buffer)) {
+                            child->block_size(), child, scratch_buffer,
+                            residual)) {
             return false;
           }
         }
@@ -2007,10 +2025,11 @@
   }
 }
 
-void Tile::ClearBlockDecoded(SuperBlockBuffer* const sb_buffer, int row4x4,
-                             int column4x4) {
+void Tile::ClearBlockDecoded(DecoderScratchBuffer* const scratch_buffer,
+                             int row4x4, int column4x4) {
   // Set everything to false.
-  memset(sb_buffer->block_decoded, 0, sizeof(sb_buffer->block_decoded));
+  memset(scratch_buffer->block_decoded, 0,
+         sizeof(scratch_buffer->block_decoded));
   // Set specific edge cases to true.
   const int sb_size4 = sequence_header_.use_128x128_superblock ? 32 : 16;
   for (int plane = 0; plane < PlaneCount(); ++plane) {
@@ -2026,7 +2045,7 @@
     // }
     const int num_elements =
         std::min((sb_size4 >> subsampling_x_[plane]) + 1, sb_width4) + 1;
-    memset(&sb_buffer->block_decoded[plane][0][0], 1, num_elements);
+    memset(&scratch_buffer->block_decoded[plane][0][0], 1, num_elements);
     // The for loop is equivalent to the following lines in the spec:
     // for ( y = -1; y <= ( sbSize4 >> subY ); y++ )
     //   if ( x < 0 && y < sbHeight4 )
@@ -2036,13 +2055,13 @@
     // BlockDecoded[plane][sbSize4 >> subY][-1] = 0
     for (int y = -1; y < std::min((sb_size4 >> subsampling_y), sb_height4);
          ++y) {
-      sb_buffer->block_decoded[plane][y + 1][0] = true;
+      scratch_buffer->block_decoded[plane][y + 1][0] = true;
     }
   }
 }
 
 bool Tile::ProcessSuperBlock(int row4x4, int column4x4, int block_width4x4,
-                             SuperBlockBuffer* const sb_buffer,
+                             DecoderScratchBuffer* const scratch_buffer,
                              ProcessingMode mode) {
   const bool parsing =
       mode == kProcessingModeParseOnly || mode == kProcessingModeParseAndDecode;
@@ -2053,7 +2072,7 @@
     ResetCdef(row4x4, column4x4);
   }
   if (decoding) {
-    ClearBlockDecoded(sb_buffer, row4x4, column4x4);
+    ClearBlockDecoded(scratch_buffer, row4x4, column4x4);
   }
   const BlockSize block_size = SuperBlockSize();
   if (parsing) {
@@ -2062,18 +2081,18 @@
   const int row = row4x4 / block_width4x4;
   const int column = column4x4 / block_width4x4;
   if (parsing && decoding) {
-    sb_buffer->residual = residual_buffer_.get();
+    uint8_t* residual_buffer = residual_buffer_.get();
     if (!ProcessPartition(row4x4, column4x4,
                           block_parameters_holder_.Tree(row, column),
-                          sb_buffer)) {
+                          scratch_buffer, &residual_buffer)) {
       LIBGAV1_DLOG(ERROR, "Error decoding partition row: %d column: %d", row4x4,
                    column4x4);
       return false;
     }
     return true;
   }
-  const int sb_row_index = (row4x4 - row4x4_start_) / block_width4x4;
-  const int sb_column_index = (column4x4 - column4x4_start_) / block_width4x4;
+  const int sb_row_index = SuperBlockRowIndex(row4x4);
+  const int sb_column_index = SuperBlockColumnIndex(column4x4);
   if (parsing) {
     residual_buffer_threaded_[sb_row_index][sb_column_index] =
         residual_buffer_pool_->Get();
@@ -2081,26 +2100,20 @@
       LIBGAV1_DLOG(ERROR, "Failed to get residual buffer.");
       return false;
     }
-    sb_buffer->residual =
+    uint8_t* residual_buffer =
         residual_buffer_threaded_[sb_row_index][sb_column_index]->buffer();
-    sb_buffer->transform_parameters =
-        residual_buffer_threaded_[sb_row_index][sb_column_index]
-            ->transform_parameters();
     if (!ProcessPartition(row4x4, column4x4,
                           block_parameters_holder_.Tree(row, column),
-                          sb_buffer)) {
+                          scratch_buffer, &residual_buffer)) {
       LIBGAV1_DLOG(ERROR, "Error parsing partition row: %d column: %d", row4x4,
                    column4x4);
       return false;
     }
   } else {
-    sb_buffer->residual =
+    uint8_t* residual_buffer =
         residual_buffer_threaded_[sb_row_index][sb_column_index]->buffer();
-    sb_buffer->transform_parameters =
-        residual_buffer_threaded_[sb_row_index][sb_column_index]
-            ->transform_parameters();
     if (!DecodeSuperBlock(block_parameters_holder_.Tree(row, column),
-                          sb_buffer)) {
+                          scratch_buffer, &residual_buffer)) {
       LIBGAV1_DLOG(ERROR, "Error decoding superblock row: %d column: %d",
                    row4x4, column4x4);
       return false;
@@ -2112,7 +2125,8 @@
 }
 
 bool Tile::DecodeSuperBlock(ParameterTree* const tree,
-                            SuperBlockBuffer* const sb_buffer) {
+                            DecoderScratchBuffer* const scratch_buffer,
+                            ResidualPtr* residual) {
   Stack<ParameterTree*, kDfsStackSize> stack;
   stack.Push(tree);
   while (!stack.Empty()) {
@@ -2124,7 +2138,7 @@
       }
       continue;
     }
-    if (!DecodeBlock(node, sb_buffer)) {
+    if (!DecodeBlock(node, scratch_buffer, residual)) {
       LIBGAV1_DLOG(ERROR, "Error decoding block row: %d column: %d",
                    node->row4x4(), node->column4x4());
       return false;
diff --git a/libgav1/src/utils/array_2d.h b/libgav1/src/utils/array_2d.h
index c054178..08b68c3 100644
--- a/libgav1/src/utils/array_2d.h
+++ b/libgav1/src/utils/array_2d.h
@@ -5,6 +5,7 @@
 #include <cstring>
 #include <memory>
 #include <new>
+#include <type_traits>
 
 #include "src/utils/compiler_attributes.h"
 
@@ -61,7 +62,9 @@
   LIBGAV1_MUST_USE_RESULT bool Reset(int rows, int columns,
                                      bool zero_initialize = true) {
     const size_t size = rows * columns;
-    if (size_ < size) {
+    // If T is not a trivial type, we should always reallocate the data_
+    // buffer, so that the destructors of any existing objects are invoked.
+    if (!std::is_trivial<T>::value || size_ < size) {
       // Note: This invokes the global operator new if T is a non-class type,
       // such as integer or enum types, or a class type that is not derived
       // from libgav1::Allocable, such as std::unique_ptr. If we enforce a
@@ -79,7 +82,10 @@
       }
       size_ = size;
     } else if (zero_initialize) {
-      memset(data_.get(), 0, sizeof(T) * size);
+      // Cast the data_ pointer to void* to avoid the GCC -Wclass-memaccess
+      // warning. The memset is safe because T is a trivial type.
+      void* dest = data_.get();
+      memset(dest, 0, sizeof(T) * size);
     }
     data_view_.Reset(rows, columns, data_.get());
     return true;
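
The new check matters because reusing a buffer for a non-trivial T would skip the destructors of the old elements and leak whatever they owned; only a reallocation runs them. A minimal sketch of the policy; for example, an Array2D of std::unique_ptr always takes the reallocation path:

#include <cstddef>
#include <type_traits>

// Sketch of the reallocation policy above: trivial types may reuse a
// large-enough buffer; non-trivial types always reallocate so that the
// destructors of existing elements run.
template <typename T>
bool MustReallocate(size_t capacity, size_t requested_size) {
  return !std::is_trivial<T>::value || capacity < requested_size;
}
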
diff --git a/libgav1/src/utils/bit_reader.h b/libgav1/src/utils/bit_reader.h
index 3b67ad1..172107c 100644
--- a/libgav1/src/utils/bit_reader.h
+++ b/libgav1/src/utils/bit_reader.h
@@ -10,6 +10,8 @@
   virtual ~BitReader() = default;
 
   virtual int ReadBit() = 0;
+  // |num_bits| has to be <= 32. The function returns a value in the range [0,
+  // 2^num_bits - 1] (inclusive) on success and -1 on failure.
   virtual int64_t ReadLiteral(int num_bits) = 0;
 
   bool DecodeSignedSubexpWithReference(int low, int high, int reference,
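
The contract documented above reads back naturally as code. A sketch under the stated assumptions, with read_bit standing in for the reader's ReadBit() and a negative return signalling failure:

#include <cstdint>

// Reads |num_bits| (<= 32) bits, most-significant first, into a literal.
int64_t ReadLiteralSketch(int num_bits, int (*read_bit)()) {
  uint32_t literal = 0;
  for (int bit = num_bits - 1; bit >= 0; --bit) {
    const int value = read_bit();
    if (value < 0) return -1;  // Propagate the failure to the caller.
    literal |= static_cast<uint32_t>(value) << bit;
  }
  return literal;
}
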
diff --git a/libgav1/src/utils/constants.h b/libgav1/src/utils/constants.h
index 5799a91..cacd8c8 100644
--- a/libgav1/src/utils/constants.h
+++ b/libgav1/src/utils/constants.h
@@ -42,11 +42,22 @@
   kMinPaletteSize = 2,
   kMaxPaletteSquare = 64,
   kBorderPixels = 64,
+  // Although the left and right borders of a frame start out kBorderPixels
+  // wide, they may change if YuvBuffer::ShiftBuffer() is called. These
+  // constants are the minimum widths, in pixels, of the left and right
+  // borders that extend the frame boundary. The minimums are derived from
+  // the following requirements:
+  // - Warp_C() may read up to 13 pixels before or after a row.
+  // - Warp_NEON() may read up to 13 pixels before a row. It may read up to 14
+  //   pixels after a row, but the value of the last read pixel is not used.
+  kMinLeftBorderPixels = 13,
+  kMinRightBorderPixels = 13,
   kWarpedModelPrecisionBits = 16,
   kMaxRefMvStackSize = 8,
   kExtraWeightForNearestMvs = 640,
   kMaxLeastSquaresSamples = 8,
   kMaxSuperBlockSizeInPixels = 128,
+  kMaxSuperBlockSizeSquareInPixels = 128 * 128,
   kNum4x4InLoopFilterMaskUnit = 16,
   kRestorationUnitOffset = 8,
   // 2 pixel padding for 5x5 box sum on each side.
@@ -103,6 +114,8 @@
   kMaxFrameDistance = 31,
   kReferenceFrameScalePrecision = 14,
   kNumWienerCoefficients = 3,
+  // Maximum number of threads that the library will ever create.
+  kMaxThreads = 32,
 };  // anonymous enum
 
 enum FrameType : uint8_t {
diff --git a/libgav1/src/utils/entropy_decoder.cc b/libgav1/src/utils/entropy_decoder.cc
index ff6db39..d8fdc20 100644
--- a/libgav1/src/utils/entropy_decoder.cc
+++ b/libgav1/src/utils/entropy_decoder.cc
@@ -25,6 +25,56 @@
          (kMinimumProbabilityPerSymbol * (symbol_count - index));
 }
 
+void UpdateCdf(uint16_t* const cdf, int symbol_count, int symbol) {
+  const uint16_t count = cdf[symbol_count];
+  // rate is computed in the spec as:
+  //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
+  // In this case cdf[N] is |count|.
+  // Min(FloorLog2(N), 2) is 1 for symbol_count == {2, 3} and 2 for all
+  // symbol_count > 3. So the equation becomes:
+  //  4 + (count > 15) + (count > 31) + (symbol_count > 3).
+  // Note that the largest value for count is 32 (it is not incremented beyond
+  // 32). So using that information:
+  //  count >> 4 is 0 for count from 0 to 15.
+  //  count >> 4 is 1 for count from 16 to 31.
+  //  count >> 4 is 2 for count == 32.
+  // Now, the equation becomes:
+  //  4 + (count >> 4) + (symbol_count > 3).
+  // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
+  // with bitwise or. So the final equation is:
+  // (4 | (count >> 4)) + (symbol_count > 3).
+  const int rate = (4 | (count >> 4)) + static_cast<int>(symbol_count > 3);
+  // Hints for further optimizations:
+  //
+  // 1. clang can vectorize this for loop with width 4, even though the loop
+  // contains an if-else statement. Therefore, it may be advantageous to use
+  // "i < symbol_count" as the loop condition when symbol_count is 8, 12, or 16
+  // (a multiple of 4 that's not too small).
+  //
+  // 2. The for loop can be rewritten in the following form, which would enable
+  // clang to vectorize the loop with width 8:
+  //
+  //   const int mask = (1 << rate) - 1;
+  //   for (int i = 0; i < symbol_count - 1; ++i) {
+  //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : mask;
+  //     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
+  //   }
+  //
+  // The subtraction (a - cdf[i]) relies on the overflow semantics of unsigned
+  // integer arithmetic. The result of the unsigned subtraction is cast to a
+  // signed integer and right-shifted. This requires the right shift of a
+  // signed integer be an arithmetic shift, which is true for clang, gcc, and
+  // Visual C++.
+  for (int i = 0; i < symbol_count - 1; ++i) {
+    if (i < symbol) {
+      cdf[i] += (libgav1::kCdfMaxProbability - cdf[i]) >> rate;
+    } else {
+      cdf[i] -= cdf[i] >> rate;
+    }
+  }
+  cdf[symbol_count] += static_cast<uint16_t>(count < 32);
+}
+
 }  // namespace
 
 namespace libgav1 {
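
The rate simplification in the comment above is easy to spot-check mechanically: compare the spec formula against the shift/or form for every reachable count (0 through 32, since count is never incremented past 32) and symbol_count (2 through 16). A small verification sketch:

#include <cassert>

// Spec form: 3 + (cdf[N] > 15) + (cdf[N] > 31) + Min(FloorLog2(N), 2).
int SpecRate(int count, int symbol_count) {
  const int min_floor_log2 = (symbol_count > 3) ? 2 : 1;
  return 3 + static_cast<int>(count > 15) + static_cast<int>(count > 31) +
         min_floor_log2;
}

// Simplified form used above.
int FastRate(int count, int symbol_count) {
  return (4 | (count >> 4)) + static_cast<int>(symbol_count > 3);
}

void VerifyRates() {
  for (int symbol_count = 2; symbol_count <= 16; ++symbol_count) {
    for (int count = 0; count <= 32; ++count) {
      assert(SpecRate(count, symbol_count) == FastRate(count, symbol_count));
    }
  }
}
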
@@ -66,7 +116,7 @@
 }
 
 int64_t DaalaBitReader::ReadLiteral(int num_bits) {
-  if (num_bits > 32) return -1;
+  assert(num_bits <= 32);
   uint32_t literal = 0;
   for (int bit = num_bits - 1; bit >= 0; --bit) {
     literal |= static_cast<uint32_t>(ReadBit()) << bit;
@@ -100,7 +150,7 @@
     // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
     // with bitwise or. So the final equation is:
     //  4 | (count >> 4).
-    const uint8_t rate = 4 | (count >> 4);
+    const int rate = 4 | (count >> 4);
     if (symbol) {
       cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
     } else {
@@ -115,6 +165,18 @@
   return ReadSymbolImpl(cdf) != 0;
 }
 
+template <int symbol_count>
+int DaalaBitReader::ReadSymbol(uint16_t* const cdf) {
+  static_assert(symbol_count >= 3 && symbol_count <= 16, "");
+  const int symbol = (symbol_count <= 13)
+                         ? ReadSymbolImpl(cdf, symbol_count)
+                         : ReadSymbolImplBinarySearch(cdf, symbol_count);
+  if (allow_update_cdf_) {
+    UpdateCdf(cdf, symbol_count, symbol);
+  }
+  return symbol;
+}
+
 int DaalaBitReader::ReadSymbolImpl(const uint16_t* const cdf,
                                    int symbol_count) {
   assert(cdf[symbol_count - 1] == 0);
@@ -218,55 +280,14 @@
   if (bits_ < 0) PopulateBits();
 }
 
-void DaalaBitReader::UpdateCdf(uint16_t* const cdf, int symbol_count,
-                               int symbol) {
-  const uint16_t count = cdf[symbol_count];
-  // rate is computed in the spec as:
-  //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
-  // In this case cdf[N] is |count|.
-  // Min(FloorLog2(N), 2) is 1 for symbol_count == {2, 3} and 2 for all
-  // symbol_count > 3. So the equation becomes:
-  //  4 + (count > 15) + (count > 31) + (symbol_count > 3).
-  // Note that the largest value for count is 32 (it is not incremented beyond
-  // 32). So using that information:
-  //  count >> 4 is 0 for count from 0 to 15.
-  //  count >> 4 is 1 for count from 16 to 31.
-  //  count >> 4 is 2 for count == 31.
-  // Now, the equation becomes:
-  //  4 + (count >> 4) + (symbol_count > 3).
-  // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
-  // with bitwise or. So the final equation is:
-  // (4 | (count >> 4)) + (symbol_count > 3).
-  const int rate = (4 | (count >> 4)) + static_cast<int>(symbol_count > 3);
-  // Hints for further optimizations:
-  //
-  // 1. clang can vectorize this for loop with width 4, even though the loop
-  // contains an if-else statement. Therefore, it may be advantageous to use
-  // "i < symbol_count" as the loop condition when symbol_count is 8, 12, or 16
-  // (a multiple of 4 that's not too small).
-  //
-  // 2. The for loop can be rewritten in the following form, which would enable
-  // clang to vectorize the loop with width 8:
-  //
-  //   const int mask = (1 << rate) - 1;
-  //   for (int i = 0; i < symbol_count - 1; ++i) {
-  //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : mask;
-  //     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
-  //   }
-  //
-  // The subtraction (a - cdf[i]) relies on the overflow semantics of unsigned
-  // integer arithmetic. The result of the unsigned subtraction is cast to a
-  // signed integer and right-shifted. This requires the right shift of a
-  // signed integer be an arithmetic shift, which is true for clang, gcc, and
-  // Visual C++.
-  for (int i = 0; i < symbol_count - 1; ++i) {
-    if (i < symbol) {
-      cdf[i] += (libgav1::kCdfMaxProbability - cdf[i]) >> rate;
-    } else {
-      cdf[i] -= cdf[i] >> rate;
-    }
-  }
-  cdf[symbol_count] += static_cast<uint16_t>(count < 32);
-}
+// Explicit instantiations.
+template int DaalaBitReader::ReadSymbol<3>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<4>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
 
 }  // namespace libgav1
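
Moving the ReadSymbol<>() body out of the header works because of the explicit instantiations above: the definition lives in one translation unit, and every symbol_count used anywhere in the library must appear in that list, or the build fails at link time rather than compile time. The pattern in general form (names hypothetical):

// widget.h: callers see only the declaration.
template <int n>
int ReadN();

// widget.cc: the definition plus one explicit instantiation per value of n
// used elsewhere in the program.
template <int n>
int ReadN() {
  return 2 * n;
}

template int ReadN<3>();
template int ReadN<16>();
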
diff --git a/libgav1/src/utils/entropy_decoder.h b/libgav1/src/utils/entropy_decoder.h
index b891bae..dfaf36a 100644
--- a/libgav1/src/utils/entropy_decoder.h
+++ b/libgav1/src/utils/entropy_decoder.h
@@ -30,16 +30,7 @@
   // on |symbol_count|. ReadSymbol calls for which the |symbol_count| is known
   // at compile time will use this variant.
   template <int symbol_count>
-  int ReadSymbol(uint16_t* const cdf) {
-    static_assert(symbol_count >= 3 && symbol_count <= 16, "");
-    const int symbol = (symbol_count <= 13)
-                           ? ReadSymbolImpl(cdf, symbol_count)
-                           : ReadSymbolImplBinarySearch(cdf, symbol_count);
-    if (allow_update_cdf_) {
-      UpdateCdf(cdf, symbol_count, symbol);
-    }
-    return symbol;
-  }
+  int ReadSymbol(uint16_t* cdf);
 
  private:
   using WindowSize = uint32_t;
@@ -66,7 +57,6 @@
   // Normalizes the range so that 32768 <= |values_in_range_| < 65536. Also
   // calls PopulateBits() if necessary.
   void NormalizeRange();
-  void UpdateCdf(uint16_t* cdf, int symbol_count, int symbol);
 
   const uint8_t* data_;
   const size_t size_;
diff --git a/libgav1/src/utils/raw_bit_reader.cc b/libgav1/src/utils/raw_bit_reader.cc
index 084535a..eb9c5c9 100644
--- a/libgav1/src/utils/raw_bit_reader.cc
+++ b/libgav1/src/utils/raw_bit_reader.cc
@@ -34,8 +34,7 @@
   assert(data_ != nullptr || size_ == 0);
 }
 
-int RawBitReader::ReadBit() {
-  if (Finished()) return -1;
+int RawBitReader::ReadBitImpl() {
   const size_t byte_offset = DivideBy8(bit_offset_, false);
   const uint8_t byte = data_[byte_offset];
   const uint8_t shift = 7 - Mod8(bit_offset_);
@@ -43,11 +42,19 @@
   return static_cast<int>((byte >> shift) & 0x01);
 }
 
+int RawBitReader::ReadBit() {
+  if (Finished()) return -1;
+  return ReadBitImpl();
+}
+
 int64_t RawBitReader::ReadLiteral(int num_bits) {
+  assert(num_bits <= 32);
   if (!CanReadLiteral(num_bits)) return -1;
   uint32_t value = 0;
+  // We can now call ReadBitImpl() since we've made sure that there are enough
+  // bits to be read.
   for (int i = num_bits - 1; i >= 0; --i) {
-    value |= static_cast<uint32_t>(ReadBit()) << i;
+    value |= static_cast<uint32_t>(ReadBitImpl()) << i;
   }
   return value;
 }
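
The split above hoists the end-of-data check out of the per-bit loop: ReadLiteral() validates once through CanReadLiteral() and then calls the unchecked ReadBitImpl(), while the public ReadBit() keeps its own check. A self-contained sketch of the shape (a simplified reader, not the real RawBitReader):

#include <cstddef>
#include <cstdint>

class MiniBitReader {
 public:
  MiniBitReader(const uint8_t* data, size_t size) : data_(data), size_(size) {}

  int ReadBit() {  // Checked: safe to call at any position.
    if (bit_offset_ >= size_ * 8) return -1;
    return ReadBitImpl();
  }

  int64_t ReadLiteral(int num_bits) {  // One check, then unchecked reads.
    if (bit_offset_ + num_bits > size_ * 8) return -1;
    uint32_t value = 0;
    for (int i = num_bits - 1; i >= 0; --i) {
      value |= static_cast<uint32_t>(ReadBitImpl()) << i;
    }
    return value;
  }

 private:
  int ReadBitImpl() {  // Unchecked: caller guarantees bits remain.
    const uint8_t byte = data_[bit_offset_ >> 3];
    const int shift = 7 - (bit_offset_ & 7);
    ++bit_offset_;
    return (byte >> shift) & 1;
  }

  const uint8_t* data_;
  size_t size_;
  size_t bit_offset_ = 0;
};
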
@@ -177,7 +184,6 @@
 
 bool RawBitReader::CanReadLiteral(size_t num_bits) const {
   if (Finished()) return false;
-  if (DivideBy8(num_bits, true) > sizeof(int)) return false;
   const size_t bit_offset = bit_offset_ + num_bits - 1;
   return DivideBy8(bit_offset, false) < size_;
 }
diff --git a/libgav1/src/utils/raw_bit_reader.h b/libgav1/src/utils/raw_bit_reader.h
index de2bc3a..b95ad27 100644
--- a/libgav1/src/utils/raw_bit_reader.h
+++ b/libgav1/src/utils/raw_bit_reader.h
@@ -50,6 +50,7 @@
  private:
   // Returns true if it is safe to read a literal of size |num_bits|.
   bool CanReadLiteral(size_t num_bits) const;
+  int ReadBitImpl();
 
   const uint8_t* const data_;
   size_t bit_offset_;
diff --git a/libgav1/src/utils/stack.h b/libgav1/src/utils/stack.h
index 59c6061..3d0cf4c 100644
--- a/libgav1/src/utils/stack.h
+++ b/libgav1/src/utils/stack.h
@@ -2,11 +2,12 @@
 #define LIBGAV1_SRC_UTILS_STACK_H_
 
 #include <cassert>
+#include <utility>
 
 namespace libgav1 {
 
-// A LIFO stack of a fixed capacity. The elements are copied, so the element
-// type T should be small.
+// A LIFO stack of a fixed capacity. The elements are moved using std::move, so
+// the element type T has to be movable.
 //
 // WARNING: No error checking is performed.
 template <typename T, int capacity>
@@ -17,14 +18,14 @@
   void Push(T value) {
     ++top_;
     assert(top_ < capacity);
-    elements_[top_] = value;
+    elements_[top_] = std::move(value);
   }
 
   // Returns the element at the top of the stack and removes it from the stack.
   // It is an error to call Pop() when the stack is empty.
   T Pop() {
     assert(top_ >= 0);
-    return elements_[top_--];
+    return std::move(elements_[top_--]);
   }
 
   // Returns true if the stack is empty.
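
With Push() and Pop() moving their elements, the stack can now hold move-only types, which the old copy-based version could not. A usage sketch (C++11-style construction; the include path is assumed):

#include <memory>

#include "src/utils/stack.h"

void StackUsageSketch() {
  libgav1::Stack<std::unique_ptr<int>, 8> stack;
  stack.Push(std::unique_ptr<int>(new int(42)));  // Moved into the stack.
  std::unique_ptr<int> value = stack.Pop();       // Moved back out.
  // |value| now owns the int; the popped slot holds a moved-from pointer.
}
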
diff --git a/libgav1/src/utils/threadpool.h b/libgav1/src/utils/threadpool.h
index 238bc44..5d7a369 100644
--- a/libgav1/src/utils/threadpool.h
+++ b/libgav1/src/utils/threadpool.h
@@ -109,8 +109,8 @@
 
 #else  // !LIBGAV1_THREADPOOL_USE_STD_MUTEX
 
-  void LockMutex() EXCLUSIVE_LOCK_FUNCTION() { queue_mutex_.Lock(); }
-  void UnlockMutex() UNLOCK_FUNCTION() { queue_mutex_.Unlock(); }
+  void LockMutex() ABSL_EXCLUSIVE_LOCK_FUNCTION() { queue_mutex_.Lock(); }
+  void UnlockMutex() ABSL_UNLOCK_FUNCTION() { queue_mutex_.Unlock(); }
   void Wait() { condition_.Wait(&queue_mutex_); }
   void SignalOne() { condition_.Signal(); }
   void SignalAll() { condition_.SignalAll(); }
diff --git a/libgav1/src/utils/vector.h b/libgav1/src/utils/vector.h
index f60a0a0..24ca9b3 100644
--- a/libgav1/src/utils/vector.h
+++ b/libgav1/src/utils/vector.h
@@ -82,7 +82,12 @@
       if (new_items == nullptr) return false;
       if (num_items_ > 0) {
         if (std::is_trivial<T>::value) {
-          memcpy(new_items, items_, num_items_ * sizeof(T));
+          // Cast |new_items| and |items_| to void* to avoid the GCC
+          // -Wclass-memaccess warning and additionally the
+          // bugprone-undefined-memory-manipulation clang-tidy warning. The
+          // memcpy is safe because T is a trivial type.
+          memcpy(static_cast<void*>(new_items),
+                 static_cast<const void*>(items_), num_items_ * sizeof(T));
         } else {
           for (size_t i = 0; i < num_items_; ++i) {
             new (&new_items[i]) T(std::move(items_[i]));
@@ -188,7 +193,12 @@
     for (iterator it = first; it != last; ++it) it->~T();
     if (last != end()) {
       if (std::is_trivial<T>::value) {
-        memmove(first, last, (end() - last) * sizeof(T));
+        // Cast |first| and |last| to void* to avoid the GCC
+        // -Wclass-memaccess warning and additionally the
+        // bugprone-undefined-memory-manipulation clang-tidy warning. The
+        // memmove is safe because T is a trivial type.
+        memmove(static_cast<void*>(first), static_cast<const void*>(last),
+                (end() - last) * sizeof(T));
       } else {
         for (iterator it_src = last, it_dst = first; it_src != end();
              ++it_src, ++it_dst) {
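
The same casting idiom appears in both hunks of this file; in general form it is the sketch below, where the static_assert records why a byte-wise copy is valid in the first place:

#include <cstddef>
#include <cstring>
#include <type_traits>

// memcpy on objects of class type trips GCC's -Wclass-memaccess even when T
// is trivial; routing the pointers through void* silences the warning while
// the static_assert documents the precondition that makes the copy legal.
template <typename T>
void TrivialCopy(T* dest, const T* src, size_t count) {
  static_assert(std::is_trivial<T>::value, "raw copy requires a trivial T");
  memcpy(static_cast<void*>(dest), static_cast<const void*>(src),
         count * sizeof(T));
}
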